#include <map> 
#include <string> 
#include <vector> 
#include <mutex> 
#include "opencl_source_map.hpp" 
namespace MNN { 
// Namespace-scope mutex with external linkage, defined in this translation unit.
// Its users are not visible in this file. NOTE(review): presumably it serializes
// access to shared OpenCL program/source state; confirm against callers.
// Brace-initialized: std::mutex's constexpr default constructor makes this
// constant-initialized, so it is usable before dynamic initialization runs.
std::mutex gCLMutex{};
const char* conv_2d = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=RI_F(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n"
"#define CALCULATE_OUTPUT(i) "" out##i = mad(in##i.x, weights0, out##i); "" out##i = mad(in##i.y, weights1, out##i); "" out##i = mad(in##i.z, weights2, out##i); "" out##i=mad(in##i.w,weights3,out##i); \n"
"#define CALCULATE_OUTPUT_WEIGHTS4(i, j) "" out##i = mad(in##j.x, weights4, out##i); "" out##i = mad(in##j.y, weights5, out##i); "" out##i = mad(in##j.z, weights6, out##i); "" out##i=mad(in##j.w,weights7,out##i);\n"
"#define CALCULATE_OUTPUT_OPT(i) "" out##i = mad(in_sm##i[local_idx].x, weights0, out##i); "" out##i = mad(in_sm##i[local_idx].y, weights1, out##i); "" out##i = mad(in_sm##i[local_idx].z, weights2, out##i); "" out##i=mad(in_sm##i[local_idx].w,weights3,out##i); \n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define UNIT 4\n"
"#define MOD_NUM 15\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,__read_only image2d_t input,\n"
" #ifdef BUFFER_INP_FP32\n"
" __global const float *kernel_ptr,\n"
" __global const float *bias_ptr,\n"
" #else\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" #endif\n"
" __write_only image2d_t output,\n"
" __private const int in_c_block,__private const int out_h,\n"
" __private const int out_w) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_w4_idx=mul24(out_w_idx,4);\n"
" #ifdef BUFFER_INP_FP32\n"
" FLOAT4 out0=CONVERT_FLOAT4(vload4(out_c_idx,(__global float *)bias_ptr));\n"
" #else\n"
" FLOAT4 out0=vload4(out_c_idx,(__global FLOAT *)bias_ptr);\n"
" #endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" FLOAT4 weights0;\n"
" FLOAT4 weights1;\n"
" FLOAT4 weights2;\n"
" FLOAT4 weights3;\n"
" FLOAT4 in0; \n"
" FLOAT4 in1; \n"
" FLOAT4 in2;\n"
" FLOAT4 in3; \n"
" FLOAT16 weight16;\n"
" const int intput_width_idx0=out_w4_idx;\n"
" const int intput_width_idx1=out_w4_idx+1;\n"
" const int intput_width_idx2=out_w4_idx+2;\n"
" const int intput_width_idx3=out_w4_idx+3;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" int input_width_base=mul24(in_channel_block_idx,out_w);\n"
" int offset=mad24(out_c_idx,in_c_block,in_channel_block_idx)*4;\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,out_b_h_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,out_b_h_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,out_b_h_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,out_b_h_idx));\n"
" #ifdef BUFFER_INP_FP32\n"
" weights0=CONVERT_FLOAT4(vload4(offset,(__global float *)kernel_ptr));\n"
" weights1=CONVERT_FLOAT4(vload4(offset+1,(__global float *)kernel_ptr));\n"
" weights2=CONVERT_FLOAT4(vload4(offset+2,(__global float *)kernel_ptr));\n"
" weights3=CONVERT_FLOAT4(vload4(offset+3,(__global float *)kernel_ptr));\n"
" #else\n"
" weights0=vload4(offset,(__global FLOAT *)kernel_ptr);\n"
" weights1=vload4(offset+1,(__global FLOAT *)kernel_ptr);\n"
" weights2=vload4(offset+2,(__global FLOAT *)kernel_ptr);\n"
" weights3=vload4(offset+3,(__global FLOAT *)kernel_ptr);\n"
" #endif\n"
" \n"
" out0.x += dot(weights0,in0);\n"
" out0.y += dot(weights1,in0);\n"
" out0.z += dot(weights2,in0);\n"
" out0.w += dot(weights3,in0);\n"
" out1.x += dot(weights0,in1);\n"
" out1.y += dot(weights1,in1);\n"
" out1.z += dot(weights2,in1);\n"
" out1.w += dot(weights3,in1);\n"
" out2.x += dot(weights0,in2);\n"
" out2.y += dot(weights1,in2);\n"
" out2.z += dot(weights2,in2);\n"
" out2.w += dot(weights3,in2);\n"
" out3.x += dot(weights0,in3);\n"
" out3.y += dot(weights1,in3);\n"
" out3.z += dot(weights2,in3);\n"
" out3.w += dot(weights3,in3);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=out_c_idx*out_w;\n"
" const int remain=out_w-out_w4_idx;\n"
" int output_idx=out_x_base+out_w4_idx;\n"
" \n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,out_b_h_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,out_b_h_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,out_b_h_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block,__private const int2 output_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int output_width_4,\n"
" __private const int out_channel_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n"
" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_ic_offset=output_channel_block_idx*8;\n"
" int weight_oc_offset=out_channel_blocks*8;\n"
"#else\n"
" int weight_ic_offset=output_channel_block_idx*16;\n"
" int weight_oc_offset=out_channel_blocks*16;\n"
"#endif\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_block_idx,0));\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
"#ifdef MNN_CONV_S1D1\n"
" int intput_width_idx0=output_width_block_idx << 2;\n"
" int intput_width_idx1=intput_width_idx0+1;\n"
" int intput_width_idx2=intput_width_idx0+2;\n"
" int intput_width_idx3=intput_width_idx0+3;\n"
"#else\n"
" int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n"
" int intput_width_idx1=intput_width_idx0+stride_shape.y;\n"
" int intput_width_idx2=intput_width_idx1+stride_shape.y;\n"
" int intput_width_idx3=intput_width_idx2+stride_shape.y;\n"
" intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n"
" intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n"
" intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n"
" intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n"
"#endif\n"
" int batch_index=output_batch_height_idx/output_shape.x;\n"
" int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n"
" FLOAT4 in0;\n"
" FLOAT4 in1;\n"
" FLOAT4 in2;\n"
" FLOAT4 in3;\n"
" FLOAT4 weights0;\n"
" FLOAT4 weights1;\n"
" FLOAT4 weights2;\n"
" FLOAT4 weights3;\n"
" int weight_offset=output_channel_block_idx*in_channel_block*4*4;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" int input_width_base=in_channel_block_idx*input_shape.y;\n"
" int weights_width_base=in_channel_block_idx << 2;\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" FLOAT4 weights0=CONVERT_FLOAT4(weights.s0123)*scale0+offset0;\n"
" FLOAT4 weights1=CONVERT_FLOAT4(weights.s4567)*scale0+offset0;\n"
" FLOAT4 weights2=CONVERT_FLOAT4(weights.s89ab)*scale0+offset0;\n"
" FLOAT4 weights3=CONVERT_FLOAT4(weights.scdef)*scale0+offset0;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n"
" char4 charWeights0=(char4)(0,0,0,0);\n"
" char4 charWeights1=(char4)(0,0,0,0);\n"
" char4 charWeights2=(char4)(0,0,0,0);\n"
" char4 charWeights3=(char4)(0,0,0,0);\n"
" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n"
" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n"
" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n"
" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)- 8;\n"
" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n"
" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n"
" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n"
" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(weights_width_base,weights+weight_offset);\n"
" weights1=vload4(weights_width_base+1,weights+weight_offset);\n"
" weights2=vload4(weights_width_base+2,weights+weight_offset);\n"
" weights3=vload4(weights_width_base+3,weights+weight_offset);\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_block_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_block_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_block_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_block_idx));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(output_channel_block_idx,output_shape.y);\n"
" int out_x_idx=output_width_block_idx << 2;\n"
" const int remain=output_shape.y-out_x_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block,__private const int2 output_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int output_width_4,\n"
" __private const int out_channel_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n"
" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n"
" const int output_channel_idx=output_channel_block_idx << 1;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_ic_offset=output_channel_block_idx*16;\n"
" int weight_oc_offset=out_channel_blocks*8;\n"
"#else\n"
" int weight_ic_offset=output_channel_block_idx*32;\n"
" int weight_oc_offset=out_channel_blocks*16;\n"
"#endif\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_idx,0));\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" \n"
" FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(output_channel_idx+1,0));\n"
" FLOAT4 out5=out4;\n"
" FLOAT4 out6=out4;\n"
" FLOAT4 out7=out4;\n"
"#ifdef MNN_CONV_S1D1\n"
" int intput_width_idx0=output_width_block_idx << 2;\n"
" int intput_width_idx1=intput_width_idx0+1;\n"
" int intput_width_idx2=intput_width_idx0+2;\n"
" int intput_width_idx3=intput_width_idx0+3;\n"
"#else\n"
" int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n"
" int intput_width_idx1=intput_width_idx0+stride_shape.y;\n"
" int intput_width_idx2=intput_width_idx1+stride_shape.y;\n"
" int intput_width_idx3=intput_width_idx2+stride_shape.y;\n"
" intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n"
" intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n"
" intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n"
" intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n"
"#endif\n"
" int batch_index=output_batch_height_idx/output_shape.x;\n"
" int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n"
" FLOAT4 in0;\n"
" FLOAT4 in1;\n"
" FLOAT4 in2;\n"
" FLOAT4 in3;\n"
" FLOAT4 weights0;\n"
" FLOAT4 weights1;\n"
" FLOAT4 weights2;\n"
" FLOAT4 weights3;\n"
" FLOAT4 weights4;\n"
" FLOAT4 weights5;\n"
" FLOAT4 weights6;\n"
" FLOAT4 weights7;\n"
" int weight_offset=output_channel_idx*in_channel_block*4*4;\n"
" int weight_offset1=weight_offset+in_channel_block*4*4;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
"#endif\n"
" \n"
" int input_width_base=in_channel_block_idx*input_shape.y;\n"
" int weights_width_base=in_channel_block_idx << 2;\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weightsInt80=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n"
" FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n"
" FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n"
" FLOAT4 weights3=CONVERT_FLOAT4(weightsInt80.scdef)*scale0+offset0;\n"
" FLOAT4 weights4=CONVERT_FLOAT4(weightsInt81.s0123)*scale1+offset1;\n"
" FLOAT4 weights5=CONVERT_FLOAT4(weightsInt81.s4567)*scale1+offset1;\n"
" FLOAT4 weights6=CONVERT_FLOAT4(weightsInt81.s89ab)*scale1+offset1;\n"
" FLOAT4 weights7=CONVERT_FLOAT4(weightsInt81.scdef)*scale1+offset1;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar16 charWeightsInt4=vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n"
" char4 charWeights0=(char4)(0,0,0,0);\n"
" char4 charWeights1=(char4)(0,0,0,0);\n"
" char4 charWeights2=(char4)(0,0,0,0);\n"
" char4 charWeights3=(char4)(0,0,0,0);\n"
" char4 charWeights4=(char4)(0,0,0,0);\n"
" char4 charWeights5=(char4)(0,0,0,0);\n"
" char4 charWeights6=(char4)(0,0,0,0);\n"
" char4 charWeights7=(char4)(0,0,0,0);\n"
" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n"
" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n"
" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n"
" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)-8;\n"
" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n"
" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n"
" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n"
" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n"
" charWeights4.x=(charWeightsInt4.s8 >> 4)-8;\n"
" charWeights4.y=(charWeightsInt4.s8 & MOD_NUM)-8;\n"
" charWeights4.z=(charWeightsInt4.s9 >> 4)-8;\n"
" charWeights4.w=(charWeightsInt4.s9 & MOD_NUM)-8;\n"
" charWeights5.x=(charWeightsInt4.sa >> 4)-8;\n"
" charWeights5.y=(charWeightsInt4.sa & MOD_NUM)-8;\n"
" charWeights5.z=(charWeightsInt4.sb >> 4)-8;\n"
" charWeights5.w=(charWeightsInt4.sb & MOD_NUM)-8;\n"
" charWeights6.x=(charWeightsInt4.sc >> 4)-8;\n"
" charWeights6.y=(charWeightsInt4.sc & MOD_NUM)-8;\n"
" charWeights6.z=(charWeightsInt4.sd >> 4)-8;\n"
" charWeights6.w=(charWeightsInt4.sd & MOD_NUM)-8;\n"
" charWeights7.x=(charWeightsInt4.se >> 4)-8;\n"
" charWeights7.y=(charWeightsInt4.se & MOD_NUM)-8;\n"
" charWeights7.z=(charWeightsInt4.sf >> 4)-8;\n"
" charWeights7.w=(charWeightsInt4.sf & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n"
" weights4=mad(CONVERT_FLOAT4(charWeights4),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeights5),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeights6),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeights7),scale1,offset1);\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(weights_width_base,weights+weight_offset);\n"
" weights1=vload4(weights_width_base+1,weights+weight_offset);\n"
" weights2=vload4(weights_width_base+2,weights+weight_offset);\n"
" weights3=vload4(weights_width_base+3,weights+weight_offset);\n"
" weights4=vload4(weights_width_base,weights+weight_offset1);\n"
" weights5=vload4(weights_width_base+1,weights+weight_offset1);\n"
" weights6=vload4(weights_width_base+2,weights+weight_offset1);\n"
" weights7=vload4(weights_width_base+3,weights+weight_offset1);\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx));\n"
" \n"
" weights4=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx+1));\n"
" weights5=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx+1));\n"
" weights6=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx+1));\n"
" weights7=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx+1));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" \n"
" CALCULATE_OUTPUT_WEIGHTS4(4,0);\n"
" CALCULATE_OUTPUT_WEIGHTS4(5,1);\n"
" CALCULATE_OUTPUT_WEIGHTS4(6,2);\n"
" CALCULATE_OUTPUT_WEIGHTS4(7,3);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
" out4=fmax(out4,(FLOAT4)0);\n"
" out5=fmax(out5,(FLOAT4)0);\n"
" out6=fmax(out6,(FLOAT4)0);\n"
" out7=fmax(out7,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
" out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n"
" out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n"
" out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n"
" out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(output_channel_idx,output_shape.y);\n"
" int out_x_idx=output_width_block_idx << 2;\n"
" const int remain=output_shape.y-out_x_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" }\n"
" \n"
" if(output_channel_idx+1 >= out_channel_blocks)\n"
" return;\n"
" output_idx += output_shape.y;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out7);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
"#ifdef BIAS\n"
" __read_only image2d_t bias,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block_length,\n"
" __private const int2 output_shape,\n"
" __private const int2 weights_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 dilation_shape,\n"
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n"
" const int out_height_block_idx=output_channel_width_idx % out_width_blocks;\n"
"#ifdef BIAS\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
"#endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" int in_width0=mad24(out_height_block_idx,stride_shape.y<<2,-padding_shape.y);\n"
" int in_width1=in_width0+stride_shape.y;\n"
" int in_width2=in_width0+stride_shape.y*2;\n"
" int in_width3=in_width0+stride_shape.y*3;\n"
" \n"
"#ifdef MNN_CONV_S1D1\n"
" const int height_start=mad24((output_batch_height_idx % output_shape.x),1,-padding_shape.x);\n"
" const int kh_start=select(0,(-height_start),height_start<0);\n"
" int in_height_start=kh_start+height_start;\n"
" int in_height_end=min(weights_shape.x+height_start,input_shape.x);\n"
" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n"
" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start),height_start<0),weights_shape.y);\n"
"#else\n"
" const int height_start=mad24((output_batch_height_idx % output_shape.x),stride_shape.x,-padding_shape.x);\n"
" const int kh_start=select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0);\n"
" int in_height_start=mad24(kh_start,dilation_shape.x,height_start);\n"
" int in_height_end=min(mad24(weights_shape.x,dilation_shape.x,height_start),input_shape.x);\n"
" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n"
" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0),weights_shape.y);\n"
"#endif\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n"
"#endif\n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" \n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+kh_start)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
" int weights_y_idx=weights_h_idx;\n"
"#endif\n"
" for (int iy=in_height_start; iy<in_height_end; iy += dilation_shape.x) {\n"
" int in_hb_value=iy+batch_idx;\n"
"#ifdef MNN_CONV_S1D1\n"
" {\n"
" READ_INPUT_IMAGE(0,0);\n"
" READ_INPUT_IMAGE(1,0);\n"
" READ_INPUT_IMAGE(2,0);\n"
" READ_INPUT_IMAGE(3,0);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
" for (int w=1; w<weights_shape.y; w++){\n"
" in0=in1;\n"
" in1=in2;\n"
" in2=in3;\n"
" READ_INPUT_IMAGE(3,w);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
"#else\n"
" for (int w=0; w<weights_shape.y; w++) {\n"
" int input_width_base=mul24(w,dilation_shape.y);\n"
" READ_INPUT_IMAGE(0,input_width_base);\n"
" READ_INPUT_IMAGE(1,input_width_base);\n"
" READ_INPUT_IMAGE(2,input_width_base);\n"
" READ_INPUT_IMAGE(3,input_width_base);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx)); \n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx)); \n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx)); \n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
"#endif\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n"
" int out_x_idx=out_height_block_idx << 2;\n"
" const int remain=output_shape.y-out_x_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
"#ifdef BIAS\n"
" __read_only image2d_t bias,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block_length,\n"
" __private const int2 output_shape,\n"
" __private const int2 weights_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 dilation_shape,\n"
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int out_channel_block_idx=(output_channel_width_idx/out_width_blocks) << 1;\n"
" const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n"
" const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n"
" const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n"
"#ifdef BIAS\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n"
" FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx+1,0));\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
" FLOAT4 out4=(FLOAT4)0;\n"
"#endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" FLOAT4 out5=out4;\n"
" FLOAT4 out6=out4;\n"
" FLOAT4 out7=out4;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" const int weight_oc_offset=weights_shape.x*weights_shape.y*4;\n"
" const int weight_ic_offset=out_channel_blocks*weight_oc_offset;\n"
"#endif\n"
" int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n"
" int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n"
" int in_height1=in_height0+stride_shape.x;\n"
" int in_height2=in_height1+stride_shape.x;\n"
" int in_height3=in_height2+stride_shape.x;\n"
" int weight_size=mul24(weights_shape.y,weights_shape.x);\n"
" \n"
" const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n"
" const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n"
" \n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3,weights4,weights5,weights6,weights7;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
" \n"
"#endif\n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
" int weights_y_idx=weights_h_idx;\n"
"#endif\n"
" for (int iy=0; iy<weights_shape.x*dilation_shape.x; iy += dilation_shape.x) {\n"
" int h0=select(in_height0+iy+batch_idx,-1,(in_height0+iy<0 || in_height0+iy >= input_shape.x));\n"
" int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n"
" int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n"
" int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n"
" for (int ix=0; ix<weights_shape.y*dilation_shape.y; ix += dilation_shape.y) {\n"
" int w0=select(in_width0+ix+in_idx,-1,(in_width0+ix<0 || in_width0+ix >= input_shape.y));\n"
" \n"
" in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n"
" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n"
" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n"
" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" charWeightInt40=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n"
" charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n"
" charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n"
" charWeight0=(char4)(0,0,0,0);\n"
" charWeight1=(char4)(0,0,0,0);\n"
" charWeight2=(char4)(0,0,0,0);\n"
" charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_ic_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_ic_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_ic_offset*3);\n"
" weights4=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights5=vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n"
" weights6=vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n"
" weights7=vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx));\n"
" weights4=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weight_size+weights_y_idx));\n"
" weights5=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weight_size+weights_y_idx));\n"
" weights6=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weight_size+weights_y_idx));\n"
" weights7=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weight_size+weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n"
" \n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" CALCULATE_OUTPUT_WEIGHTS4(4,0);\n"
" CALCULATE_OUTPUT_WEIGHTS4(5,1);\n"
" CALCULATE_OUTPUT_WEIGHTS4(6,2);\n"
" CALCULATE_OUTPUT_WEIGHTS4(7,3);\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
" out4=fmax(out4,(FLOAT4)0);\n"
" out5=fmax(out5,(FLOAT4)0);\n"
" out6=fmax(out6,(FLOAT4)0);\n"
" out7=fmax(out7,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
" out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n"
" out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n"
" out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n"
" out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n"
" const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n"
" int out_x_idx=out_width_block_idx;\n"
" int out_y_idx=out_height_block_idx << 2;\n"
" const int remain_y=output_shape.x-out_y_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" int output_idy=out_y_base+out_y_idx;\n"
" \n"
" if(remain_y >= 4){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" WI_F(output,(int2)(output_idx,output_idy+3),out3);\n"
" }else if(remain_y == 3){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" }else if(remain_y == 2){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" }else if(remain_y == 1){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" }\n"
" \n"
" if(out_channel_block_idx+1 >= out_channel_blocks) {\n"
" return;\n"
" }\n"
" output_idx += output_shape.y;\n"
" if(remain_y >= 4){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out6);\n"
" WI_F(output,(int2)(output_idx,output_idy+3),out7);\n"
" }else if(remain_y == 3){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out6);\n"
" }else if(remain_y == 2){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n"
" }else if(remain_y == 1){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
"#ifdef BIAS\n"
" __read_only image2d_t bias,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block_length,\n"
" __private const int2 output_shape,\n"
" __private const int2 weights_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 dilation_shape,\n"
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n"
" const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n"
" const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n"
" const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n"
"#ifdef BIAS\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
"#endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n"
" int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n"
" int in_height1=in_height0+stride_shape.x;\n"
" int in_height2=in_height1+stride_shape.x;\n"
" int in_height3=in_height2+stride_shape.x;\n"
" int weight_size=mul24(weights_shape.y,weights_shape.x);\n"
" \n"
" const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n"
" const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n"
" \n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n"
"#endif\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
" int weights_y_idx=weights_h_idx;\n"
"#endif\n"
" for (int iy=0; iy<weights_shape.x*dilation_shape.x; iy += dilation_shape.x) {\n"
" int h0=select(in_height0+iy+batch_idx,-1,(in_height0+iy<0 || in_height0+iy >= input_shape.x));\n"
" int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n"
" int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n"
" int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n"
" for (int ix=0; ix<weights_shape.y*dilation_shape.y; ix += dilation_shape.y) {\n"
" int w0=select(in_width0+ix+in_idx,-1,(in_width0+ix<0 || in_width0+ix >= input_shape.y));\n"
" \n"
" in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n"
" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n"
" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n"
" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n"
" const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n"
" int out_x_idx=out_width_block_idx;\n"
" int out_y_idx=out_height_block_idx << 2;\n"
" const int remain_y=output_shape.x-out_y_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" int output_idy=out_y_base+out_y_idx;\n"
" if(remain_y >= 4){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" WI_F(output,(int2)(output_idx,output_idy+3),out3);\n"
" }else if(remain_y == 3){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" }else if(remain_y == 2){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" }else{\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" }\n"
"}\n"
;
// OpenCL C program source holding two kernels: `deconv_2d` and `iohw2oihw`.
//
// deconv_2d (deconvolution / transposed 2-D convolution):
//   - Work item (out_channel_block, out_w, batch*out_h + out_h) produces one
//     FLOAT4, i.e. four consecutive output channels of one output pixel.
//   - It gathers every input position that lands on this output pixel through
//     the stride and accumulates it into out0 with mad().
//   - Shape-like int2 arguments are read as (.x = height, .y = width).
//   - Build-time switches referenced by the source:
//       MNN_SUPPORT_FP16  - enable the fp16 extension
//       USE_BUFFER        - __global buffers instead of image2d_t
//       BIAS              - add a bias
//       RELU / RELU6      - fused activation
//   - In the USE_BUFFER variant, tensors are addressed as
//     [channel_block][batch][H][W][4].
//   - FLOAT, FLOAT4, RI_F and WI_F are not defined in this source. They are
//     presumably injected by the host through build options; confirm against
//     the program builder.
//   - NOTE(review): in the buffer path, `kernel_start_x<0` in outBoundry can
//     never be true. kernel_start_x is max(0, ...) and in_width0 only
//     increments from it. The check is harmless.
//
// iohw2oihw:
//   - Copies a float tensor laid out [ic][oc][plane] into [oc][ic][plane].
//   - Each work item handles one (ic, oc) pair.
const char* deconv_2d = 
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void deconv_2d(GLOBAL_SIZE_3_DIMS\n"
" #ifdef USE_BUFFER\n"
" __global FLOAT* input,\n"
" __global FLOAT* weights,\n"
" #ifdef BIAS\n"
" __global FLOAT* bias,\n"
" #endif\n"
" __global FLOAT* output,__private const int batch,\n"
" #else\n"
" __read_only image2d_t input,\n"
" __read_only image2d_t weights,\n"
" #ifdef BIAS\n"
" __read_only image2d_t bias,\n"
" #endif\n"
" __write_only image2d_t output,\n"
" #endif\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 align_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 kernel_shape,\n"
" __private const int kernel_size,\n"
" __private const int in_channel_blocks,__private const int out_channel_blocks) {\n"
" const int out_channel_blocks_idx=get_global_id(0);\n"
" const int out_w_idx=get_global_id(1);\n"
" const int out_batch_height_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(out_channel_blocks_idx,out_w_idx,out_batch_height_idx);\n"
"#ifdef BIAS\n"
" #ifdef USE_BUFFER\n"
" FLOAT4 out0=vload4(out_channel_blocks_idx,bias);\n"
" #else\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_blocks_idx,0));\n"
" #endif\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
"#endif\n"
" const int out_b_idx=out_batch_height_idx/output_shape.x;\n"
" const int out_h_idx=out_batch_height_idx % output_shape.x;\n"
" \n"
// First input row/column that contributes to this output pixel, plus the
// kernel offset to start from. Each further input step moves the kernel
// offset back by one stride.
" int kernel_start_x=max(0,(out_w_idx+align_shape.y)/stride_shape.y);\n"
" int kernel_start_y=max(0,(out_h_idx+align_shape.x)/stride_shape.x);\n"
" int deal_kernel_width=kernel_shape.y-mad24(kernel_start_x,stride_shape.y,padding_shape.y)+out_w_idx-1;\n"
" int deal_kernel_height=kernel_shape.x-mad24(kernel_start_y,stride_shape.x,padding_shape.x)+out_h_idx-1;\n"
" \n"
" \n"
" int kernel_x_0,kernel_x_1,kernel_x_2,kernel_x_3,kernel_y;\n"
" FLOAT4 in0;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
" for (int ic=0; ic<in_channel_blocks; ic++) {\n"
" kernel_x_0=ic << 2;\n"
" kernel_x_1=kernel_x_0+1;\n"
" kernel_x_2=kernel_x_0+2;\n"
" kernel_x_3=kernel_x_0+3;\n"
" for (int k_y=deal_kernel_height,idx_h=kernel_start_y; k_y >= 0; k_y -= stride_shape.x,idx_h++) {\n"
" #ifdef USE_BUFFER\n"
" int in_width0=kernel_start_x;\n"
" for (int k_x=deal_kernel_width; k_x >= 0; k_x -= stride_shape.y) {\n"
" kernel_y=mad24(k_y,kernel_shape.y,k_x);\n"
" kernel_y=mad24(out_channel_blocks_idx,kernel_size,kernel_y);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,kernel_x_0,kernel_y,0]\n"
" weights0=vload4(kernel_x_0*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y,weights);\n"
" weights1=vload4(kernel_x_1*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y,weights);\n"
" weights2=vload4(kernel_x_2*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y,weights);\n"
" weights3=vload4(kernel_x_3*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y,weights);\n"
" bool outBoundry=(idx_h<0 || idx_h >= input_shape.x || kernel_start_x<0 || in_width0 >= input_shape.y);\n"
" int inp_offset=(((out_b_idx+ic*batch)*input_shape.x+idx_h)*input_shape.y+in_width0)*4;\n"
" in0=outBoundry ? (FLOAT4)0 : vload4(0,input+inp_offset);\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights1,out0);\n"
" out0=mad(in0.z,weights2,out0);\n"
" out0=mad(in0.w,weights3,out0);\n"
" in_width0++;\n"
" }\n"
" #else\n"
// Image path: out-of-range rows/columns are turned into coordinate -1.
// CLK_ADDRESS_CLAMP then makes the sampler return zero for them.
" int in_idy=mad24(out_b_idx,input_shape.x,idx_h);\n"
" int in_hb_value=select(in_idy,-1,idx_h<0 || idx_h >= input_shape.x);\n"
" int in_width0=kernel_start_x;\n"
" for (int k_x=deal_kernel_width; k_x >= 0; k_x -= stride_shape.y) {\n"
" kernel_y=mad24(k_y,kernel_shape.y,k_x);\n"
" kernel_y=mad24(out_channel_blocks_idx,kernel_size,kernel_y);\n"
" weights0=RI_F(weights,SAMPLER,(int2)(kernel_x_0,kernel_y));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(kernel_x_1,kernel_y));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(kernel_x_2,kernel_y));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(kernel_x_3,kernel_y));\n"
" int in_idx=mul24(ic,input_shape.y);\n"
" int in_width_value0 = in_width0; "" in_width_value0 = "" select(in_idx + in_width_value0, -1, (in_width_value0 < 0 || in_width_value0 >= input_shape.y)); "" in0=RI_F(input,SAMPLER,(int2)(in_width_value0,in_hb_value));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights1,out0);\n"
" out0=mad(in0.z,weights2,out0);\n"
" out0=mad(in0.w,weights3,out0);\n"
" in_width0++;\n"
" }\n"
" #endif\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
"#ifdef USE_BUFFER\n"
" const int out_offset=(((out_b_idx+out_channel_blocks_idx*batch)*output_shape.x+out_h_idx)*output_shape.y+out_w_idx)*4;\n"
" vstore4(out0,0,output+out_offset);\n"
"#else\n"
" int out_image_width_idx=mad24(out_channel_blocks_idx,output_shape.y,out_w_idx);\n"
" WI_F(output,(int2)(out_image_width_idx,out_batch_height_idx),out0);\n"
"#endif\n"
"}\n"
// iohw2oihw:
//   - Each work item copies plane_number contiguous floats.
//   - Source offset (ic*output_channel + oc)*plane_number.
//   - Destination offset (oc*input_channel + ic)*plane_number.
"__kernel void iohw2oihw(__global const float* input_ptr,__global float* output_ptr,int plane_number,int input_channel,int output_channel) {\n"
" const int ic_index=get_global_id(0),oc_index=get_global_id(1);\n"
" if (ic_index >= input_channel || oc_index >= output_channel) {\n"
" return;\n"
" }\n"
" const int input_offset=(ic_index*output_channel+oc_index)*plane_number;\n"
" const int output_offset=(oc_index*input_channel+ic_index)*plane_number;\n"
" for (int i=0; i<plane_number; ++i) {\n"
" output_ptr[output_offset+i]=input_ptr[input_offset+i];\n"
" }\n"
"}\n"
;
// OpenCL C program source for the element-wise `unary` image kernel.
//
// unary:
//   - Each work item reads one pixel at (channel_block*width + w, hb) and
//     converts it to float4 `in`.
//   - It evaluates the build-time OPERATOR expression, which presumably may
//     use the gelu() helper below; confirm against the host.
//   - It writes the result through CONVERT_OUTPUT_I4 / WI_DATA. RI_DATA,
//     OUTPUT_TYPE_I4, CONVERT_OUTPUT_I4, WI_DATA and OPERATOR are not defined
//     here.
//   - NOTE: `width` is taken from global_size_dim1. Dimension 1 of the launch
//     must therefore equal the per-channel-block image width exactly.
//
// gelu():
//   - Computes 0.5*x*(1 + t), where t approximates tanh(0.79788458*(x + 0.044715*x^3)).
//   - 0.79788458 is sqrt(2/pi).
//   - The tanh is a rational polynomial that is pinned to +1 for
//     arguments > 5 and to -1 for arguments <= -5.
const char* unary = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"inline float4 gelu(float4 in){\n"
" float4 value=0.79788458f*(0.044715f*in*in*in+in);\n"
" float4 x2=value*value;\n"
" float4 dst=value>(float4)5.0f ? (float4)1.0f : (value <= -(float4)5.0f ? -(float4)1.0f :\n"
" (value*(135135.0f+x2*(17325.0f+x2*(378.0f+x2))))/(135135.0f+x2*(62370.0f+x2*(3150.0f+x2*28.0f))));\n"
" return (1.0f+dst)*in*0.5f;\n"
"}\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void unary(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output) {\n"
" const int channel_block_idx=get_global_id(0);\n"
" const int w=get_global_id(1);\n"
" const int hb=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);\n"
" const int width=global_size_dim1;\n"
" const int pos=mad24(channel_block_idx,width,w);\n"
" float4 in=convert_float4(RI_DATA(input,SAMPLER,(int2)(pos,hb)));\n"
" OUTPUT_TYPE_I4 out=CONVERT_OUTPUT_I4(OPERATOR);\n"
" \n"
" WI_DATA(output,(int2)(pos,hb),out);\n"
"}\n"
;
// OpenCL C program source for buffer-based GridSample: kernels `nearest_buf`
// and `bilinear_buf`. Compiled out when MNN_OPENCL_BUFFER_CLOSED is defined.
//
// Per work item (channel_block, out_w, batch*output_height + out_h):
//   - It reads one (x, y) grid pair with vload2 at index
//     (batch*output_height + out_h)*output_width + out_w.
//   - getPosition() maps the coordinate to input pixel space;
//     alignCorners == 1 selects the align-corners mapping.
//   - It samples the input and stores one 4-channel vector.
//
// Layouts:
//   - Input and output are addressed as [channel_block][batch][H][W][4].
//
// sample():
//   - For an out-of-range tap it returns 0 under BorderMode_ZEROS.
//   - Otherwise it clamps the tap to the nearest edge pixel. Per the
//     in-kernel comment, this also covers REFLECTION, whose reflection is
//     expected to be applied beforehand.
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* grid_sample_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"enum BorderMode {\n"
" BorderMode_ZEROS=0,\n"
" BorderMode_CLAMP=1,\n"
" BorderMode_REFLECTION=2,\n"
" BorderMode_MIN=BorderMode_ZEROS,\n"
" BorderMode_MAX=BorderMode_REFLECTION\n"
"};\n"
"float getPosition(float x,int range,int alignCorners){\n"
" float a=alignCorners == 1? 1.0f : 0.0f;\n"
" float b=alignCorners == 1? 0.0f : 1.0f;\n"
" return ((1.0f+x)*(range-a)-b)/2.0f;\n"
"}\n"
"static int CLAMP(int v,int min,int max) {\n"
" if ((v)<min) {\n"
" (v)=min;\n"
" } else if ((v)>max) {\n"
" (v)=max;\n"
" }\n"
" return v;\n"
"}\n"
"COMPUTE_FLOAT4 sample(int h,int w,\n"
" const int offset_base,\n"
" __global const FLOAT *buffer,\n"
" int height,int width,\n"
" enum BorderMode paddingMode){\n"
" if (h<0 || h >= height || w<0 || w >= width) {\n"
" if(paddingMode == BorderMode_ZEROS)\n"
" {\n"
" return 0.0f;\n"
" }\n"
" // Clearly,CLAMP is the right way to go for GridSamplePaddingMode_BORDER\n"
" // For GridSamplePaddingMode_REFLECTION,since we have reflected the values into (-1,1),\n"
" // the leftover reflections degrade to GridSamplePaddingMode_BORDER\n"
" h=CLAMP(h,0,height-1);\n"
" w=CLAMP(w,0,width-1);\n"
" }\n"
" int offset=(offset_base+h)*width+w;\n"
" return CONVERT_COMPUTE_FLOAT4(vload4(offset,buffer));\n"
"}\n"
// nearest_buf: rounds the mapped coordinate with floor(v + 0.5f) and takes a
// single tap.
"__kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n"
" __global const FLOAT* grid,\n"
" __global FLOAT* output,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int output_height,\n"
" __private const int output_width,\n"
" __private const int batch,\n"
" __private const enum BorderMode paddingMode,\n"
" __private const int alignCorners){\n"
" \n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_batch_idx=output_batch_height_block_idx/output_height;\n"
" const int output_height_idx=output_batch_height_block_idx % output_height;\n"
" // grid data format has been converted from nchw to nc4hw4\n"
" /* \n"
" (x1,x1,x1,x1) (y1,y2,y3,y4) \n"
" . . \n"
" . . slice\n"
" (x1,y1)...(xn,y1) . . \n"
" . . (xn,xn,xn,xn) (y1,y2,y3,y4)\n"
" . . <-> ---------------------------\n"
" . . (x1,x1,x1,x1) (y5,y6,y7,y8)\n"
" (x1,ym)...(xn,ym) . .\n"
" . . slice\n"
" . .\n"
" (xn,xn,xn,xn) (y5,y6,y7,y8)\n"
" ---------------------------\n"
" */\n"
" // output_width_block_idx means gird y offset,2 means grid width\n"
" const int grid_offset=(output_batch_idx*output_height+output_height_idx)*output_width+output_width_block_idx;\n"
" COMPUTE_FLOAT2 grid_xy=CONVERT_COMPUTE_FLOAT2(vload2(grid_offset,grid));\n"
" // get grid x,y\n"
" const float x=(float)grid_xy.x;\n"
" const float y=(float)grid_xy.y;\n"
" // convert grid x,y to input x,y coordinate range\n"
" float in_grid_x=getPosition(x,input_width,alignCorners);\n"
" float in_grid_y=getPosition(y,input_height,alignCorners);\n"
" // get nearest point\n"
" int nw=floor(in_grid_x+0.5f);\n"
" int nh=floor(in_grid_y+0.5f);\n"
" const int inp_offset_base=(output_batch_idx+output_channel_block_idx*batch)*input_height;\n"
" COMPUTE_FLOAT4 value=sample(nh,nw,inp_offset_base,input,input_height,input_width,paddingMode);\n"
" const int output_offset=((output_batch_idx+output_channel_block_idx*batch)*output_height+output_height_idx)*output_width+output_width_block_idx;\n"
" vstore4(CONVERT_FLOAT4(value),output_offset,output);\n"
"}\n"
// bilinear_buf: blends the four taps (floor/ceil in each axis).
//   - x_weight = w1 - x goes on the w0 column; y_weight = h1 - y goes on the
//     h0 row.
//   - When a coordinate is integral, floor == ceil, so the weight collapses
//     onto one tap.
"__kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n"
" __global const FLOAT* grid,\n"
" __global FLOAT* output,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int output_height,\n"
" __private const int output_width,\n"
" __private const int batch,\n"
" __private const enum BorderMode paddingMode,\n"
" __private const int alignCorners){\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_batch_idx=output_batch_height_block_idx/output_height;\n"
" const int output_height_idx=output_batch_height_block_idx % output_height;\n"
" // output_width_block_idx means gird y offset,2 means grid width\n"
" const int grid_offset=(output_batch_idx*output_height+output_height_idx)*output_width+output_width_block_idx;\n"
" COMPUTE_FLOAT2 grid_xy=CONVERT_COMPUTE_FLOAT2(vload2(grid_offset,grid));\n"
" \n"
" // get grid x,y\n"
" const float x=(float)grid_xy.x;\n"
" const float y=(float)grid_xy.y;\n"
" // convert grid x,y to input x,y coordinate range\n"
" float in_grid_x=getPosition(x,input_width,alignCorners);\n"
" float in_grid_y=getPosition(y,input_height,alignCorners);\n"
" int in_h0=floor(in_grid_y);\n"
" int in_w0=floor(in_grid_x);\n"
" int in_h1=ceil(in_grid_y);\n"
" int in_w1=ceil(in_grid_x);\n"
" float x_weight=in_w1-in_grid_x;\n"
" float y_weight=in_h1-in_grid_y;\n"
" // bilinear interpolation\n"
" const int inp_offset_base=(output_batch_idx+output_channel_block_idx*batch)*input_height;\n"
" COMPUTE_FLOAT4 i00=sample(in_h0,in_w0,inp_offset_base,input,input_height,input_width,paddingMode);\n"
" COMPUTE_FLOAT4 i01=sample(in_h0,in_w1,inp_offset_base,input,input_height,input_width,paddingMode);\n"
" COMPUTE_FLOAT4 i10=sample(in_h1,in_w0,inp_offset_base,input,input_height,input_width,paddingMode);\n"
" COMPUTE_FLOAT4 i11=sample(in_h1,in_w1,inp_offset_base,input,input_height,input_width,paddingMode);\n"
" COMPUTE_FLOAT4 value=CONVERT_COMPUTE_FLOAT4(((COMPUTE_FLOAT4)x_weight*CONVERT_COMPUTE_FLOAT4(i00)+(COMPUTE_FLOAT4)(1.0f-x_weight)*CONVERT_COMPUTE_FLOAT4(i01))*(COMPUTE_FLOAT4)y_weight +\n"
" ((COMPUTE_FLOAT4)x_weight*CONVERT_COMPUTE_FLOAT4(i10)+(COMPUTE_FLOAT4)(1.0f-x_weight)*CONVERT_COMPUTE_FLOAT4(i11))*(COMPUTE_FLOAT4)(1.0f- y_weight));\n"
" \n"
" const int output_offset=((output_batch_idx+output_channel_block_idx*batch)*output_height+output_height_idx)*output_width+output_width_block_idx;\n"
" vstore4(CONVERT_FLOAT4(value),output_offset,output);\n"
"}\n"
;
#endif
// OpenCL C program source for the `interp` image kernel (bilinear resize).
//
// interp:
//   - Output pixel (channel_block, out_w, batch*out_height + out_h) maps to
//     source coordinate (out_h*height_scale + height_offset,
//     out_w*width_scale + width_offset).
//   - The floor/floor+1 neighbours are clamped to [0, input_height-1] and
//     [0, input_width-1].
//   - The neighbours are blended with mad(), first along width and then
//     along height.
//   - Pixel I/O uses read_imagef/write_imagef in float precision.
//   - NOTE: the output width is taken from global_size_dim1, so dimension 1
//     of the launch must equal the output width exactly.
//     output_channel_block_idxs is computed but never used.
const char* interp = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void interp(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const float height_scale,__private const float width_scale,\n"
" __private const float height_offset,__private const float width_offset,\n"
" __private const int input_height,__private const int input_width,\n"
" __private const int out_height) {\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_channel_block_idxs=global_size_dim0;\n"
" const int output_width=global_size_dim1;\n"
" const int output_batch_idx=output_batch_height_block_idx/out_height;\n"
" const int output_height_idx=output_batch_height_block_idx % out_height;\n"
" const float scale_height=output_height_idx*height_scale+height_offset;\n"
" const float scale_width=output_width_block_idx*width_scale+width_offset;\n"
"#define CLAMP(val,min_val,max_val) max(min(val,max_val),min_val)\n"
" const int height_floor=(int)floor(scale_height);\n"
" const int height_lf=CLAMP(height_floor,0,input_height-1);\n"
" const int height_uf=CLAMP(height_floor+1,0,input_height-1);\n"
" \n"
" const int width_floor=(int)floor(scale_width);\n"
" const int width_lf=CLAMP(width_floor,0,input_width-1);\n"
" const int width_uf=CLAMP(width_floor+1,0,input_width-1);\n"
" const float height_gap=scale_height-height_floor;\n"
" const float width_gap=scale_width-width_floor;\n"
" const int input_width_offset=mul24(output_channel_block_idx,input_width);\n"
" const int input_height_offset=mul24(output_batch_idx,input_height);\n"
" float4 top_left =\n"
" read_imagef(input,SAMPLER,(int2)(input_width_offset+width_lf,input_height_offset+height_lf));\n"
" float4 top_right =\n"
" read_imagef(input,SAMPLER,(int2)(input_width_offset+width_uf,input_height_offset+height_lf));\n"
" float4 bottom_left =\n"
" read_imagef(input,SAMPLER,(int2)(input_width_offset+width_lf,input_height_offset+height_uf));\n"
" float4 bottom_right =\n"
" read_imagef(input,SAMPLER,(int2)(input_width_offset+width_uf,input_height_offset+height_uf));\n"
" float4 top=mad((top_right-top_left),width_gap,top_left);\n"
" float4 bottom=mad((bottom_right-bottom_left),width_gap,bottom_left);\n"
" float4 out=mad((bottom-top),height_gap,top);\n"
" const int out_image_w=mad24(output_channel_block_idx,output_width,output_width_block_idx);\n"
" const int out_image_h=mad24(output_batch_idx,out_height,output_height_idx);\n"
" write_imagef(output,(int2)(out_image_w,out_image_h),out);\n"
"}\n"
;
// OpenCL C program source for the `select_img` image kernel.
//
// select_img:
//   - For each pixel it reads the int4 condition from `input`.
//   - Per component, the output is input0 where the condition equals 1 and
//     input1 otherwise. This is OpenCL select(a, b, c), which picks b where
//     c is true.
//   - INSIZE1_EUQAL_1 / INSIZE2_EUQAL_1 make input0 / input1 a broadcast of
//     its first scalar.
//   - The macro names are spelled "EUQAL" here. Presumably the host defines
//     them with that exact spelling, so do not correct it on one side only.
const char* select = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void select_img(GLOBAL_SIZE_2_DIMS\n"
" __read_only image2d_t input,\n"
" __read_only image2d_t input0,\n"
" __read_only image2d_t input1,\n"
" __write_only image2d_t output\n"
" ) {\n"
" const int idx=get_global_id(0);\n"
" const int idy=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(idx,idy);\n"
" int4 select_vec=read_imagei(input,SAMPLER,(int2)(idx,idy));\n"
"#ifdef INSIZE1_EUQAL_1\n"
" FLOAT4 in0=RI_F(input0,SAMPLER,(int2)(0,0));\n"
" in0=(FLOAT4)(in0.x);\n"
"#else\n"
" FLOAT4 in0=RI_F(input0,SAMPLER,(int2)(idx,idy));\n"
"#endif\n"
" \n"
"#ifdef INSIZE2_EUQAL_1\n"
" FLOAT4 in1=RI_F(input1,SAMPLER,(int2)(0,0));\n"
" in1=(FLOAT4)(in1.x);\n"
"#else\n"
" FLOAT4 in1=RI_F(input1,SAMPLER,(int2)(idx,idy));\n"
"#endif\n"
" FLOAT4 out=select(in1,in0,select_vec == (int4)1);\n"
" WI_F(output,(int2)(idx,idy),out);\n"
"}\n"
;
// OpenCL C program source for the buffer kernel `range_buf`. Compiled out when
// MNN_OPENCL_BUFFER_CLOSED is defined.
//
// range_buf:
//   - Writes output[i] = start + i*step, with start = input0[0] and
//     step = input2[0].
//   - Work item (x, y) fills the four elements beginning at 4*x; y is only
//     bounds-checked.
//   - With PACK_LEAVE, a vector that would reach index `size` or beyond is
//     written element by element, up to size-1.
//   - Without PACK_LEAVE, every work item stores a full vector.
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* range_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__kernel void range_buf(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE* input0,\n"
" __global const INPUT_TYPE* input2,\n"
" __global OUTPUT_TYPE* output,\n"
" __private const int size\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" \n"
" int index=x << 2;\n"
" int4 index4=(int4)(index,index+1,index+2,index+3);\n"
" INPUT_TYPE start=input0[0];\n"
" INPUT_TYPE step=input2[0];\n"
" OUTPUT_TYPE4 value=(OUTPUT_TYPE4)start+CONVERT_OUTPUT4(index4)*(OUTPUT_TYPE4)step;\n"
"#ifdef PACK_LEAVE\n"
" if(index+3 >= size){\n"
" OUTPUT_TYPE* value_ptr=(OUTPUT_TYPE*)&value;\n"
" for(int i=0; i<size-index; ++i){\n"
" output[index+i]=value_ptr[i];\n"
" }\n"
" }else{\n"
"#endif\n"
" vstore4(value,0,output+index);\n"
"#ifdef PACK_LEAVE\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
// OpenCL C program source for the buffer-based self-attention helpers.
// Compiled out when MNN_OPENCL_BUFFER_CLOSED is defined.
//
// Kernels:
//   split_transpose_qkv
//     - Splits the packed QKV input
//       (addressed as [seq/4][batch][head][3*head_dim][4]) into:
//         Q and K laid out [batch*head][head_dim_pad][seq_pad]
//         V laid out [batch*head][seq_pad][head_dim_pad]
//     - Zero-fills padding.
//     - When seq_index > 0 only Q is produced, for one sequence piece.
//   softmax_inside
//     - Softmax over the innermost axis (first inside_len entries); entries
//       up to shape.z are zero-padded.
//     - Uses a local-memory tree reduction of SOFTMAX_LOCAL_SIZE slots.
//     - The reduction assumes the work-group size in dimension 0 equals
//       SOFTMAX_LOCAL_SIZE; TODO confirm on the host side.
//   trans_3d_buf
//     - 8x8-tiled transpose of [N][width][height] into [N][height][width].
//     - Presumably both extents are multiples of 8.
//   clip_transpose_qkv
//     - Copies [batch*head][head_dim_pad][seq_piece] back into the packed
//       [seq/4][batch][head][head_dim][4] layout.
//     - Stores only head-dim rows < head_dim.
//
// Data layout inside split_transpose_qkv:
//   - temp_k holds head-dim row 4*hd+k.
//   - Its components .x/.y/.z/.w hold sequence positions 4*sl..4*sl+3.
//   - Evidence: the input is packed by 4 along seq, successive vload4s step
//     by 4 along head_dim, and rows are stored to output_q/output_k row-wise.
// Hence:
//   - DEAL_HEAD_DIM_NOT_ALIGN clears whole vectors temp_k for
//     4*hd+k >= head_dim.
//   - DEAL_SEQ_LEN_NOT_ALIGN clears individual components for sequence
//     positions >= seq_len.
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* self_attention_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// Head-dim padding: temp_k is head-dim row 4*hd+k, so drop whole rows.
"#define DEAL_HEAD_DIM_NOT_ALIGN "" if(hd * 4 + 3 >= head_dim) {"" temp_3 = (FLOAT4)0;"" }"" if(hd * 4 + 2 >= head_dim) {"" temp_2 = (FLOAT4)0;"" }"" if(hd * 4 + 1 >= head_dim) {"" temp_1 = (FLOAT4)0;"" }\n"
// Sequence padding: components .y/.z/.w are seq positions 4*sl+1..4*sl+3.
"#define DEAL_SEQ_LEN_NOT_ALIGN "" if(4 * sl + 3 >= seq_len) {"" temp_0.w = (FLOAT)0;"" temp_1.w = (FLOAT)0;"" temp_2.w = (FLOAT)0;"" temp_3.w = (FLOAT)0;"" }"" if(4 * sl + 2 >= seq_len) {"" temp_0.z = (FLOAT)0;"" temp_1.z = (FLOAT)0;"" temp_2.z = (FLOAT)0;"" temp_3.z = (FLOAT)0;"" }"" if(4 * sl + 1 >= seq_len) {"" temp_0.y = (FLOAT)0;"" temp_1.y = (FLOAT)0;"" temp_2.y = (FLOAT)0;"" temp_3.y = (FLOAT)0;"" }\n"
"__kernel void split_transpose_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [Batch,seqLen/4,mNumHead*3*mHeadDim,4]\n"
" __global FLOAT *output_q,// [Batch*mNumHead,head_dim_pack_k,seq_len_pack_mn/qSeqSplitNum]\n"
" __global FLOAT *output_k,// [Batch*mNumHead,head_dim_pack_k,seq_len_pack_mn]\n"
" __global FLOAT *output_v,// [Batch*mNumHead,ROUND_UP(seqLen,tile),head_dim_pack_mn]\n"
" __private const int seq_len_pack_mn,\n"
" __private const int seq_len_piece,\n"
" __private const int head_dim_pack_mn,\n"
" __private const int head_dim_pack_k,\n"
" __private const int seq_len,\n"
" __private const int head_num,\n"
" __private const int head_dim,\n"
" __private const int batch,\n"
" __private const int seq_index\n"
") {\n"
" const int sl=get_global_id(0); // seqLen_4\n"
" const int hd=get_global_id(1); // mHeadDim_4\n"
" const int z=get_global_id(2); // Batch*mNumHead\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int b=z/head_num;\n"
" const int hn=z % head_num;\n"
" \n"
" const int offset_q=((b*head_num+hn)*head_dim_pack_k+4*hd)*seq_len_piece+4*sl;\n"
" if(seq_index>0) {\n"
" // fill output_q only\n"
" if(sl*4 >= seq_len || hd*4 >= head_dim) {\n"
" if(hd*4<head_dim_pack_k) {\n"
" if(sl*4<seq_len_piece) {\n"
" vstore4((FLOAT4)0,0,output_q+offset_q);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" }\n"
" return;\n"
" }\n"
" \n"
" const int offset_inp=((((seq_index*seq_len_piece/4+sl)*batch+b)*head_num+hn)*3*head_dim+4*hd)*4;\n"
" if(sl*4<seq_len_piece) {\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+4);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+8);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" vstore4(temp_0,0,output_q+offset_q);\n"
" vstore4(temp_1,0,output_q+offset_q+seq_len_piece);\n"
" vstore4(temp_2,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4(temp_3,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" return;\n"
" }\n"
" const int offset_k=((b*head_num+hn)*head_dim_pack_k+4*hd)*seq_len_pack_mn+4*sl;\n"
" const int offset_v=((b*head_num+hn)*seq_len_pack_mn+4*sl)*head_dim_pack_mn+4*hd;\n"
" if(sl*4 >= seq_len || hd*4 >= head_dim) {\n"
" if(hd*4<head_dim_pack_k) {\n"
" if(sl*4<seq_len_piece) {\n"
" vstore4((FLOAT4)0,0,output_q+offset_q);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" vstore4((FLOAT4)0,0,output_k+offset_k);\n"
" vstore4((FLOAT4)0,0,output_k+offset_k+seq_len_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn+seq_len_pack_mn);\n"
" }\n"
" vstore4((FLOAT4)0,0,output_v+offset_v);\n"
" vstore4((FLOAT4)0,0,output_v+offset_v+head_dim_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn+head_dim_pack_mn);\n"
" \n"
" return;\n"
" }\n"
" \n"
" const int offset_inp=(((sl*batch+b)*head_num+hn)*3*head_dim+4*hd)*4;\n"
" \n"
" if(sl*4<seq_len_piece) {\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+4);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+8);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" vstore4(temp_0,0,output_q+offset_q);\n"
" vstore4(temp_1,0,output_q+offset_q+seq_len_piece);\n"
" vstore4(temp_2,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4(temp_3,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" \n"
" {\n"
" // K\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp+4*head_dim);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+4*head_dim+4);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+4*head_dim+8);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+4*head_dim+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" \n"
" vstore4(temp_0,0,output_k+offset_k);\n"
" vstore4(temp_1,0,output_k+offset_k+seq_len_pack_mn);\n"
" vstore4(temp_2,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn);\n"
" vstore4(temp_3,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn+seq_len_pack_mn);\n"
" \n"
" // V\n"
" temp_0=vload4(0,input+offset_inp+8*head_dim);\n"
" temp_1=vload4(0,input+offset_inp+8*head_dim+4);\n"
" temp_2=vload4(0,input+offset_inp+8*head_dim+8);\n"
" temp_3=vload4(0,input+offset_inp+8*head_dim+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" \n"
" vstore4((FLOAT4){temp_0.x,temp_1.x,temp_2.x,temp_3.x},0,output_v+offset_v);\n"
" vstore4((FLOAT4){temp_0.y,temp_1.y,temp_2.y,temp_3.y},0,output_v+offset_v+head_dim_pack_mn);\n"
" vstore4((FLOAT4){temp_0.z,temp_1.z,temp_2.z,temp_3.z},0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn);\n"
" vstore4((FLOAT4){temp_0.w,temp_1.w,temp_2.w,temp_3.w},0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn+head_dim_pack_mn);\n"
" }\n"
"}\n"
"#ifndef SOFTMAX_LOCAL_SIZE\n"
" #define SOFTMAX_LOCAL_SIZE 512\n"
"#endif\n"
"// [outside,axis,inside] -> reduce: inside\n"
"__kernel void softmax_inside(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [batch*mNumHead,ROUND_UP(seqLen,tile),ROUND_UP(seqLen,tile)]\n"
" __global FLOAT *output,\n"
" __private const int inside_len,\n"
" __private const int4 shape // [batch*mNumHead,ROUND_UP(seqLen,tile),ROUND_UP(seqLen,tile)]\n"
" ) {\n"
" const int inside=get_global_id(0);\n"
" const int axis=get_global_id(1);\n"
" const int outside=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(inside,axis,outside);\n"
" const int offset=(outside*shape.y+axis)*shape.z+0;\n"
" int lid=get_local_id(0);\n"
" float local sum[SOFTMAX_LOCAL_SIZE];\n"
" /*Compute Max */\n"
" float maxValue=(float)(-FLT_MAX);\n"
" // clip to seq_len\n"
" for (int i=lid; i<inside_len; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,(float)input[offset+ i]);\n"
" }\n"
" sum[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" #pragma unroll\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i >>= 1){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=sum[0];\n"
" /*Compute Exp Sum*/\n"
" float sumValue=0;\n"
" for (int i=lid; i<inside_len; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp((float)input[offset+ i]-maxValue);\n"
" }\n"
" sum[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" #pragma unroll\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i >>= 1){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum[0];\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" const int out_offset=(outside*shape.z+0)*shape.y+axis;\n"
" #endif\n"
" /*Compute Result */\n"
" for (int i=lid; i<inside_len; i+=SOFTMAX_LOCAL_SIZE) {\n"
" float value=exp((float)input[offset+ i]-maxValue)/sumValue;\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" output[out_offset+ i*shape.y]=value;\n"
" #else\n"
" output[offset+ i]=value;\n"
" #endif\n"
" }\n"
" if(shape.z>inside_len){\n"
" for(int i=lid+inside_len; i<shape.z; i+=SOFTMAX_LOCAL_SIZE){\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" output[out_offset+ i*shape.y]=(FLOAT)0;\n"
" #else\n"
" output[offset+ i]=(FLOAT)0;\n"
" #endif\n"
" }\n"
" }\n"
"}\n"
"// [N X Y4 4] -> [N Y X]\n"
"__kernel void trans_3d_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT* input,\n"
" __global FLOAT* output,\n"
" __private const int batch,\n"
" __private const int width,\n"
" __private const int height\n"
") {\n"
" int b=get_global_id(2);\n"
" int w=get_global_id(0);\n"
" int h=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM3(w,h,b);\n"
" w=w << 3;\n"
" h=h << 3;\n"
" \n"
" const int inp_offset=(b*width+w)*height+h;\n"
" const int out_offset=(b*height+h)*width+w;\n"
" FLOAT8 value_0=vload8(0,input+inp_offset);\n"
" FLOAT8 value_1=vload8(0,input+inp_offset+height);\n"
" FLOAT8 value_2=vload8(0,input+inp_offset+height+height);\n"
" FLOAT8 value_3=vload8(0,input+inp_offset+height+height+height);\n"
" FLOAT8 value_4=vload8(0,input+inp_offset+(height << 2));\n"
" FLOAT8 value_5=vload8(0,input+inp_offset+height*5);\n"
" FLOAT8 value_6=vload8(0,input+inp_offset+height*6);\n"
" FLOAT8 value_7=vload8(0,input+inp_offset+height*7);\n"
" \n"
" vstore8((FLOAT8){value_0.s0,value_1.s0,value_2.s0,value_3.s0,value_4.s0,value_5.s0,value_6.s0,value_7.s0},0,output+out_offset);\n"
" vstore8((FLOAT8){value_0.s1,value_1.s1,value_2.s1,value_3.s1,value_4.s1,value_5.s1,value_6.s1,value_7.s1},0,output+out_offset+width);\n"
" vstore8((FLOAT8){value_0.s2,value_1.s2,value_2.s2,value_3.s2,value_4.s2,value_5.s2,value_6.s2,value_7.s2},0,output+out_offset+width+width);\n"
" vstore8((FLOAT8){value_0.s3,value_1.s3,value_2.s3,value_3.s3,value_4.s3,value_5.s3,value_6.s3,value_7.s3},0,output+out_offset+width+width+width);\n"
" vstore8((FLOAT8){value_0.s4,value_1.s4,value_2.s4,value_3.s4,value_4.s4,value_5.s4,value_6.s4,value_7.s4},0,output+out_offset+(width << 2));\n"
" vstore8((FLOAT8){value_0.s5,value_1.s5,value_2.s5,value_3.s5,value_4.s5,value_5.s5,value_6.s5,value_7.s5},0,output+out_offset+width*5);\n"
" vstore8((FLOAT8){value_0.s6,value_1.s6,value_2.s6,value_3.s6,value_4.s6,value_5.s6,value_6.s6,value_7.s6},0,output+out_offset+width*6);\n"
" vstore8((FLOAT8){value_0.s7,value_1.s7,value_2.s7,value_3.s7,value_4.s7,value_5.s7,value_6.s7,value_7.s7},0,output+out_offset+width*7);\n"
"}\n"
"__kernel void clip_transpose_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [Batch*mNumHead,ROUND_UP(mHeadDim,tile),ROUND_UP(seqLen,tile)]\n"
" __global FLOAT *output,// [Batch,seqLen/4,mNumHead*mHeadDim,4]\n"
" __private const int tile,\n"
" __private const int seq_len,\n"
" __private const int seq_len_piece,\n"
" __private const int head_num,\n"
" __private const int head_dim,\n"
" __private const int batch,\n"
" __private const int seq_index\n"
") {\n"
" \n"
" const int sl=get_global_id(0); // seqLen_Piece_4\n"
" const int hd=get_global_id(1); // mHeadDim_4\n"
" const int z=get_global_id(2); // Batch*mNumHead\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int b=z/head_num;\n"
" const int hn=z % head_num;\n"
" \n"
" const int seq_len_4=(seq_len+3)/4;\n"
" \n"
" if(seq_index*seq_len_piece/4+sl >= seq_len_4) {\n"
" return;\n"
" }\n"
" const int seq_len_pack=seq_len_piece;//((seq_len+tile-1)/tile)*tile;\n"
" const int head_dim_pack=((head_dim+tile-1)/tile)*tile;\n"
" \n"
" const int offset_inp=((b*head_num+hn)*head_dim_pack+4*hd)*seq_len_pack+4*sl;\n"
" \n"
" const int offset_out=((((seq_index*seq_len_piece/4+sl)*batch+b)*head_num+hn)*head_dim+4*hd)*4;\n"
" // Q\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+seq_len_pack);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+2*seq_len_pack);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+3*seq_len_pack);\n"
" \n"
" // Row 4*hd+k is valid only while 4*hd+k<head_dim. A '>' guard would let\n"
" // row head_dim overwrite the next head's first row.\n"
" vstore4(temp_0,0,output+offset_out);\n"
" if(4*hd+1>=head_dim) return;\n"
" vstore4(temp_1,0,output+offset_out+4);\n"
" if(4*hd+2>=head_dim) return;\n"
" vstore4(temp_2,0,output+offset_out+8);\n"
" if(4*hd+3>=head_dim) return;\n"
" vstore4(temp_3,0,output+offset_out+12);\n"
"}\n"
;
#endif
// OpenCL micro-benchmark kernel source.
// MAD_V4 expands to 4 dependent mad() steps that alternately update x and y;
// each larger macro repeats the previous one 4 times, so MAD_V16 = 16 and
// MAD_V64 = 64 steps.
// NOTE(review): the names stop matching the counts after that: MAD_V128 is
//   4 x MAD_V64 = 256 mad() steps and MAD_V256 is 4 x MAD_V128 = 1024 steps.
//   Anything that derives op counts from the macro names is off by 2x/4x.
// float_precision: seeds float scalars from mul_value and the local id, runs
//   MAD_V256 32 times (32768 scalar mad steps), and stores mul_y at
//   output_ptr[global id].
// half4_precision: the same chain on half4 vectors, run 16 times (16384 half4
//   mad steps); stores the sum of mul_y's four lanes as a half.
// NOTE(review): presumably used by the host to compare fp32 vs fp16 ALU
//   throughput; the stored values only keep the work from being optimized
//   away — confirm against the caller.
// NOTE(review): half4_precision uses `half`, which needs cl_khr_fp16, but the
//   enabling pragma below is only emitted when MNN_SUPPORT_FP16 is defined.
const char* performance = 
"#define MAD_V4(x, y) "" x = mad(y, x, y); "" y = mad(x, y, x); "" x = mad(y, x, y); "" y=mad(x,y,x);\n"
"#define MAD_V16(x, y) "" MAD_V4(x, y); "" MAD_V4(x, y); "" MAD_V4(x, y); "" MAD_V4(x,y);\n"
"#define MAD_V64(x, y) "" MAD_V16(x, y); "" MAD_V16(x, y); "" MAD_V16(x, y); "" MAD_V16(x,y);\n"
"#define MAD_V128(x, y) "" MAD_V64(x, y); "" MAD_V64(x, y); "" MAD_V64(x, y); "" MAD_V64(x,y);\n"
"#define MAD_V256(x, y) "" MAD_V128(x, y); "" MAD_V128(x, y); "" MAD_V128(x, y); "" MAD_V128(x,y);\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// fp32 variant: 32 x MAD_V256 on float scalars.
"__kernel void float_precision(__global float* output_ptr,float mul_value) {\n"
" float mul_x=mul_value;\n"
" float mul_y=(float)get_local_id(0);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" output_ptr[get_global_id(0)]=mul_y;\n"
"}\n"
// fp16 variant: 16 x MAD_V256 on half4 vectors; output is the lane sum of mul_y.
"__kernel void half4_precision(__global half* output_ptr,float mul_value) {\n"
" half mul=(half)mul_value;\n"
" half4 mul_x=(half4)(mul);\n"
" half4 mul_y=(half4)get_local_id(0);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" output_ptr[get_global_id(0)]=(mul_y.S0)+(mul_y.S1)+(mul_y.S2)+(mul_y.S3);\n"
"}\n"
;
// OpenCL source for a Winograd input (source) transform on image memory.
// Going by the name ("2_3"), this is presumably the F(2, 3) variant. Each tile
// reads a 4x4 input window, and adjacent tiles are 2 pixels apart
// (sxStart = unitWidth_idx*2 - padX).
// Work-item layout: pos.x in [0, unitWidth*unitHeight) selects the tile
// (unitWidth_idx = pos.x % unitWidth, unitHeight_idx = pos.x / unitWidth);
// pos.y in [0, srcChannelC4) selects a 4-channel slice.
// Input addressing: the image column is sx + pos.y*srcWidth (channel slices side
// by side along x); the row is batchOffset*srcHeight + sy.
// Output: 16 transformed texels per work-item, written to column
// dstX = pos.y*unitWidth + unitWidth_idx and rows unitHeight_idx + unitHeight*t,
// t = 0..15.
const char* winogradTransformSource2_3_1 = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void winogradTransformSource(__read_only image2d_t uInput,// 0\n"
" __write_only image2d_t uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,\n"
" __private const int batchOffset) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" if (pos.x<unitWidth*unitHeight && pos.y<srcChannelC4) {\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int dstX=mad24(pos.y,unitWidth,unitWidth_idx);\n"
" {\n"
" int sxStart=unitWidth_idx*2-padX;\n"
" int syStart=unitHeight_idx*2-padY;\n"
" FLOAT4 S00;\n"
" FLOAT4 S10;\n"
" FLOAT4 S20;\n"
" FLOAT4 S30;\n"
" FLOAT4 S01;\n"
" FLOAT4 S11;\n"
" FLOAT4 S21;\n"
" FLOAT4 S31;\n"
" FLOAT4 S02;\n"
" FLOAT4 S12;\n"
" FLOAT4 S22;\n"
" FLOAT4 S32;\n"
" FLOAT4 S03;\n"
" FLOAT4 S13;\n"
" FLOAT4 S23;\n"
" FLOAT4 S33;\n"
// Gather the 4x4 input window. Sxy is the texel at column sxStart+x, row
// syStart+y. Out-of-range coordinates are replaced by -1 so the
// CLK_ADDRESS_CLAMP sampler returns the border color (zero for RGBA images),
// which implements the zero padding.
" {\n"
" int sx=0+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S00=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S10=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S20=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S30=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S01=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S11=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S21=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S31=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S02=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S12=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S22=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S32=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S03=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S13=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S23=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S33=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
// First pass, along y: mxj = row j of
//   [[1, 0, -1, 0], [0, 0.5, 0.5, 0], [0, -0.5, 0.5, 0], [0, -1, 0, 1]]
// applied to (Sx0, Sx1, Sx2, Sx3).
" FLOAT4 m00=+S00-S02;\n"
" FLOAT4 m10=+S10-S12;\n"
" FLOAT4 m20=+S20-S22;\n"
" FLOAT4 m30=+S30-S32;\n"
" FLOAT4 m01=+(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m11=+(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m21=+(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m31=+(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m02=-(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m12=-(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m22=-(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m32=-(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m03=-S01+S03;\n"
" FLOAT4 m13=-S11+S13;\n"
" FLOAT4 m23=-S21+S23;\n"
" FLOAT4 m33=-S31+S33;\n"
// Second pass, along x, with the same matrix applied to (m0j, m1j, m2j, m3j).
// Result (i, j) is written to row unitHeight_idx + unitHeight*(4*j + i).
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*0),+m00-m20);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*1),+(FLOAT)0.5f*m10+(FLOAT)0.5f*m20);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*2),-(FLOAT)0.5f*m10+(FLOAT)0.5f*m20);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*3),-m10+m30);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*4),+m01-m21);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*5),+(FLOAT)0.5f*m11+(FLOAT)0.5f*m21);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*6),-(FLOAT)0.5f*m11+(FLOAT)0.5f*m21);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*7),-m11+m31);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*8),+m02-m22);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*9),+(FLOAT)0.5f*m12+(FLOAT)0.5f*m22);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*10),-(FLOAT)0.5f*m12+(FLOAT)0.5f*m22);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*11),-m12+m32);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*12),+m03-m23);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*13),+(FLOAT)0.5f*m13+(FLOAT)0.5f*m23);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*14),-(FLOAT)0.5f*m13+(FLOAT)0.5f*m23);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*15),-m13+m33);\n"
" }\n"
" }\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL source for the 1x1-convolution GEMV kernels on low-bit quantized
// weights:
//   gemv_conv_c4_buf - 4 output channels per work-item
//   gemv_conv_c2_buf - 2 output channels per work-item
//   gemv_conv_c1_buf - 1 output channel per work-item
//
// Build-time switches read by the kernel source:
//   USE_LOW_BIT_WEIGHT_INT8 / USE_LOW_BIT_WEIGHT_INT4 - weight storage format
//   USE_IMAGE           - weights bound as an image2d_t instead of a buffer
//   INPUT_CHANNEL_LEAVE - the last CHANNEL_PACK chunk of each block gets a
//                         bounds-checked tail path (presumably set when
//                         srcChannel is not a multiple of CHANNEL_PACK; confirm)
//   RELU / RELU6        - fused activation
//
// Dequantization: weights are dequantized as q * scale + offset. Int4 nibbles
// are first shifted to [-8, 7] by subtracting 8. (scale, offset) pairs are read
// from dequantScaleOffset, one pair per output channel per block.
//
// PADZEROS contract:
//   - The first argument is the input-channel index of lane 0 of a 16-lane
//     chunk (k*CHANNEL_PACK for chunk k, i.e. k4*4 or k8*4 below), so lanes at
//     or past srcChannel are zeroed.
//   - The lane count is clamped to 16, so a chunk lying entirely past
//     srcChannel never indexes before the start of the vector.
const char* gemv_conv1x1_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"#define UCHAR16_TO_2CHAR16(a, b, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; "" b.s0 = (c.s8 >> 4) - 8; b.s1 = (c.s8 & 15) - 8; b.s2 = (c.s9 >> 4) - 8; b.s3 = (c.s9 & 15) - 8; b.s4 = (c.sa >> 4) - 8; b.s5 = (c.sa & 15) - 8; b.s6 = (c.sb >> 4) - 8; b.s7 = (c.sb & 15) - 8; "" b.s8=(c.sc >> 4)-8; b.s9=(c.sc & 15)-8; b.sa=(c.sd >> 4)-8; b.sb=(c.sd & 15)-8; b.sc=(c.se >> 4)-8; b.sd=(c.se & 15)-8; b.se=(c.sf >> 4)-8; b.sf=(c.sf & 15)-8;\n"
"#define UCHAR8_TO_CHAR16(a, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8=(c.s4 >> 4)-8; a.s9=(c.s4 & 15)-8; a.sa=(c.s5 >> 4)-8; a.sb=(c.s5 & 15)-8; a.sc=(c.s6 >> 4)-8; a.sd=(c.s6 & 15)-8; a.se=(c.s7 >> 4)-8; a.sf=(c.s7 & 15)-8;\n"
"#define DOT16X16(a, b, c) "" c += dot(a.s0123, b.s0123); "" c += dot(a.s4567, b.s4567); "" c += dot(a.s89ab, b.s89ab); "" c += dot(a.scdef,b.scdef);\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
// PADZEROS(k, channel, data): k is the input channel of data's lane 0; lanes
// whose channel (k + lane) is >= channel are set to 0. remain is clamped to 15
// so at most all 16 lanes are touched (ptr[15 - r] never goes negative).
" #define PADZEROS(k, channel, data) {"" COMPUTE_FLOAT* ptr = (COMPUTE_FLOAT*)&data; "" int remain = min((k) + 15 - (channel), 15); "" for(int r = remain; r >= 0; r--){ "" ptr[15 - r] = 0; "" } "" }\n"
"#else\n"
" #define PADZEROS(k,channel,data)\n"
"#endif\n"
"#if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
"#define CHANNEL_PACK 32\n"
"#else\n"
"#define CHANNEL_PACK 16\n"
"#endif\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#define WEIGHT_STRIDE 16\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
"#define WEIGHT_STRIDE 8\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// readWeight: returns 16 dequantized weights (q * scale + offset).
"#ifdef USE_IMAGE\n"
"inline COMPUTE_FLOAT16 readWeight(__read_only image2d_t weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n"
" return CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(ix,iy))))*scale+offset;\n"
"}\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"inline COMPUTE_FLOAT16 readWeight(__global const char *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n"
" return CONVERT_COMPUTE_FLOAT16(vload16(0,weight))*scale+offset;\n"
"}\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
// Int4 buffer: only 8 packed bytes (16 nibbles) are read; a wider load would
// be wasted work and could run past the end of the weight buffer.
"inline COMPUTE_FLOAT16 readWeight(__global const uchar *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n"
" uchar8 charWeightsInt4=vload8(0,weight);\n"
" char16 charWeights=0;\n"
" UCHAR8_TO_CHAR16(charWeights,charWeightsInt4);\n"
" return CONVERT_COMPUTE_FLOAT16(charWeights)*scale+offset;\n"
"}\n"
"#endif\n"
"#endif\n"
// gemv_conv_c4_buf: dim0 x = group of 4 output channels, dim1 y = row ("b h w").
// Result (bias + dot products, then optional RELU/RELU6) is stored at
// output + (x*bhw + y)*4.
// NOTE(review): input is read at input + chunk*16 with no offset derived from y,
// so every y reads the same input vector; presumably only launched with
// bhw == 1 (GEMV). Confirm against the host.
"__kernel void gemv_conv_c4_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
"#endif\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
" __private const int bhw,\n"
" __private const int blockNum,\n"
" __private const int blockDim) {\n"
" const int x=get_global_id(0); //c/4\n"
" const int y=get_global_id(1); //b h w\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(x,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" int idn=x << 2;\n"
" int idm=y;\n"
" \n"
" int out_offset=(x*bhw+idm)*4;\n"
"#ifndef USE_IMAGE\n"
" int weight_offset=x*4*WEIGHT_STRIDE;\n"
" int weight_oc_offset=dstChannelC4*4*WEIGHT_STRIDE;\n"
"#endif\n"
" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" const int loop_end=max(loop-1,0);\n"
"#else\n"
" const int loop_end=loop;\n"
"#endif\n"
" \n"
" for (int i=0; i<blockNum; ++i){\n"
" int kindex=i*dstChannelC4*4*2;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(x,dequantScaleOffset+kindex));\n"
" for (int j=0; j<loop_end; ++j) {\n"
" int k=i*loop+j;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" int k32=k << 5;\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11,weights20,weights21,weights30,weights31;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn,k)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+1,k)));\n"
" uchar16 charWeightsInt42=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+2,k)));\n"
" uchar16 charWeightsInt43=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+3,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt42);\n"
" weights20=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s4+ScaleOffset.s5;\n"
" weights21=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s4+ScaleOffset.s5;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt43);\n"
" weights30=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s6+ScaleOffset.s7;\n"
" weights31=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s6+ScaleOffset.s7;\n"
" }\n"
" {\n"
" COMPUTE_FLOAT16 in0=CONVERT_COMPUTE_FLOAT16(vload16(0,input+k32));\n"
" COMPUTE_FLOAT16 in1=CONVERT_COMPUTE_FLOAT16(vload16(0,input+k32+16));\n"
" DOT16X16(in0,weights00,out0.s0);DOT16X16(in1,weights01,out0.s0);\n"
" DOT16X16(in0,weights10,out0.s1);DOT16X16(in1,weights11,out0.s1);\n"
" DOT16X16(in0,weights20,out0.s2);DOT16X16(in1,weights21,out0.s2);\n"
" DOT16X16(in0,weights30,out0.s3);DOT16X16(in1,weights31,out0.s3);\n"
" }\n"
" #else\n"
" COMPUTE_FLOAT16 weights0,weights1,weights2,weights3;\n"
" #ifdef USE_IMAGE\n"
" weights0=readWeight(weight,idn,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight,idn+1,k,ScaleOffset.s2,ScaleOffset.s3);\n"
" weights2=readWeight(weight,idn+2,k,ScaleOffset.s4,ScaleOffset.s5);\n"
" weights3=readWeight(weight,idn+3,k,ScaleOffset.s6,ScaleOffset.s7);\n"
" #else\n"
" weights0=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight+weight_offset+k*weight_oc_offset+WEIGHT_STRIDE,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" weights2=readWeight(weight+weight_offset+k*weight_oc_offset+2*WEIGHT_STRIDE,0,0,ScaleOffset.s4,ScaleOffset.s5);\n"
" weights3=readWeight(weight+weight_offset+k*weight_oc_offset+3*WEIGHT_STRIDE,0,0,ScaleOffset.s6,ScaleOffset.s7);\n"
" #endif\n"
" {\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(k,input));\n"
" DOT16X16(in,weights0,out0.s0);\n"
" DOT16X16(in,weights1,out0.s1);\n"
" DOT16X16(in,weights2,out0.s2);\n"
" DOT16X16(in,weights3,out0.s3);\n"
" }\n"
" #endif\n"
" }\n"
// Tail chunk:
//   - Input C4 slices at or past srcChannelC4 read as 0.
//   - PADZEROS clears weight lanes whose input channel is >= srcChannel.
//   - Lane 0 of the chunk is channel k8*4 (int4 image, 32 channels) or
//     k4*4 (16 channels).
" #ifdef INPUT_CHANNEL_LEAVE\n"
" {\n"
" int k=i*loop+loop_end;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" int k8=k << 3;\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11,weights20,weights21,weights30,weights31;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn,k)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+1,k)));\n"
" uchar16 charWeightsInt42=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+2,k)));\n"
" uchar16 charWeightsInt43=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+3,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt42);\n"
" weights20=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s4+ScaleOffset.s5;\n"
" weights21=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s4+ScaleOffset.s5;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt43);\n"
" weights30=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s6+ScaleOffset.s7;\n"
" weights31=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s6+ScaleOffset.s7;\n"
" \n"
" PADZEROS(k8*4,srcChannel,weights00);PADZEROS(k8*4+16,srcChannel,weights01);\n"
" PADZEROS(k8*4,srcChannel,weights10);PADZEROS(k8*4+16,srcChannel,weights11);\n"
" PADZEROS(k8*4,srcChannel,weights20);PADZEROS(k8*4+16,srcChannel,weights21);\n"
" PADZEROS(k8*4,srcChannel,weights30);PADZEROS(k8*4+16,srcChannel,weights31);\n"
" }\n"
" {\n"
" COMPUTE_FLOAT16 in0,in1;\n"
" in0.s0123=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k8*4));\n"
" in0.s4567=CONVERT_COMPUTE_FLOAT4(k8+1<srcChannelC4 ? vload4(0,input+(k8+1)*4) : (FLOAT4)0);\n"
" in0.s89ab=CONVERT_COMPUTE_FLOAT4(k8+2<srcChannelC4 ? vload4(0,input+(k8+2)*4) : (FLOAT4)0);\n"
" in0.scdef=CONVERT_COMPUTE_FLOAT4(k8+3<srcChannelC4 ? vload4(0,input+(k8+3)*4) : (FLOAT4)0);\n"
" in1.s0123=CONVERT_COMPUTE_FLOAT4(k8+4<srcChannelC4 ? vload4(0,input+(k8+4)*4) : (FLOAT4)0);\n"
" in1.s4567=CONVERT_COMPUTE_FLOAT4(k8+5<srcChannelC4 ? vload4(0,input+(k8+5)*4) : (FLOAT4)0);\n"
" in1.s89ab=CONVERT_COMPUTE_FLOAT4(k8+6<srcChannelC4 ? vload4(0,input+(k8+6)*4) : (FLOAT4)0);\n"
" in1.scdef=CONVERT_COMPUTE_FLOAT4(k8+7<srcChannelC4 ? vload4(0,input+(k8+7)*4) : (FLOAT4)0);\n"
" DOT16X16(in0,weights00,out0.s0);DOT16X16(in1,weights01,out0.s0);\n"
" DOT16X16(in0,weights10,out0.s1);DOT16X16(in1,weights11,out0.s1);\n"
" DOT16X16(in0,weights20,out0.s2);DOT16X16(in1,weights21,out0.s2);\n"
" DOT16X16(in0,weights30,out0.s3);DOT16X16(in1,weights31,out0.s3);\n"
" }\n"
" #else\n"
" int k4=k << 2;\n"
" COMPUTE_FLOAT16 weights0,weights1,weights2,weights3;\n"
" #ifdef USE_IMAGE\n"
" weights0=readWeight(weight,idn,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight,idn+1,k,ScaleOffset.s2,ScaleOffset.s3);\n"
" weights2=readWeight(weight,idn+2,k,ScaleOffset.s4,ScaleOffset.s5);\n"
" weights3=readWeight(weight,idn+3,k,ScaleOffset.s6,ScaleOffset.s7);\n"
" #else\n"
" weights0=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight+weight_offset+k*weight_oc_offset+WEIGHT_STRIDE,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" weights2=readWeight(weight+weight_offset+k*weight_oc_offset+2*WEIGHT_STRIDE,0,0,ScaleOffset.s4,ScaleOffset.s5);\n"
" weights3=readWeight(weight+weight_offset+k*weight_oc_offset+3*WEIGHT_STRIDE,0,0,ScaleOffset.s6,ScaleOffset.s7);\n"
" #endif\n"
" PADZEROS(k4*4,srcChannel,weights0);\n"
" PADZEROS(k4*4,srcChannel,weights1);\n"
" PADZEROS(k4*4,srcChannel,weights2);\n"
" PADZEROS(k4*4,srcChannel,weights3);\n"
" {\n"
" COMPUTE_FLOAT16 in;\n"
" in.s0123=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k4*4));\n"
" in.s4567=CONVERT_COMPUTE_FLOAT4(k4+1<srcChannelC4 ? vload4(0,input+(k4+1)*4) : (FLOAT4)0);\n"
" in.s89ab=CONVERT_COMPUTE_FLOAT4(k4+2<srcChannelC4 ? vload4(0,input+(k4+2)*4) : (FLOAT4)0);\n"
" in.scdef=CONVERT_COMPUTE_FLOAT4(k4+3<srcChannelC4 ? vload4(0,input+(k4+3)*4) : (FLOAT4)0);\n"
" DOT16X16(in,weights0,out0.s0);\n"
" DOT16X16(in,weights1,out0.s1);\n"
" DOT16X16(in,weights2,out0.s2);\n"
" DOT16X16(in,weights3,out0.s3);\n"
" }\n"
" #endif\n"
" }\n"
" #endif\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
"}\n"
// gemv_conv_c2_buf: same as the c4 kernel but x covers output channels 2x and
// 2x+1. Results are stored inside the C4 slice (2x)/4 at lane (2x)%4.
// NOTE(review): input is read without a y-dependent offset (see c4 kernel).
"__kernel void gemv_conv_c2_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
"#endif\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
" __private const int bhw,\n"
" __private const int blockNum,\n"
" __private const int blockDim) {\n"
" const int x=get_global_id(0); //c/2\n"
" const int y=get_global_id(1); //b h w\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" \n"
" int idn=x << 1;\n"
" int idm=y;\n"
" COMPUTE_FLOAT2 bias0=CONVERT_COMPUTE_FLOAT2(vload2(x,bias));\n"
" COMPUTE_FLOAT2 out0=bias0;\n"
" int out_offset=((x*2)/4*bhw+idm)*4+((x*2) % 4);\n"
"#ifndef USE_IMAGE\n"
" int weight_offset=x*2*WEIGHT_STRIDE;\n"
" int weight_oc_offset=dstChannelC4*4*WEIGHT_STRIDE;\n"
"#endif\n"
" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" const int loop_end=max(loop-1,0);\n"
"#else\n"
" const int loop_end=loop;\n"
"#endif\n"
" for (int i=0; i<blockNum; ++i){\n"
" int kindex=i*dstChannelC4*4*2;\n"
" COMPUTE_FLOAT4 ScaleOffset=CONVERT_COMPUTE_FLOAT4(vload4(x,dequantScaleOffset+kindex));\n"
" for (int j=0; j<loop_end; ++j) {\n"
" int k=i*loop+j;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" int k32=k << 5;\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn,k)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+1,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" }\n"
" {\n"
" COMPUTE_FLOAT16 in0=CONVERT_COMPUTE_FLOAT16(vload16(0,input+k32));\n"
" COMPUTE_FLOAT16 in1=CONVERT_COMPUTE_FLOAT16(vload16(0,input+k32+16));\n"
" DOT16X16(in0,weights00,out0.s0);DOT16X16(in1,weights01,out0.s0);\n"
" DOT16X16(in0,weights10,out0.s1);DOT16X16(in1,weights11,out0.s1);\n"
" }\n"
" #else\n"
" COMPUTE_FLOAT16 weights0,weights1;\n"
" #ifdef USE_IMAGE\n"
" weights0=readWeight(weight,idn,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight,idn+1,k,ScaleOffset.s2,ScaleOffset.s3);\n"
" #else\n"
" weights0=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight+weight_offset+k*weight_oc_offset+WEIGHT_STRIDE,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" #endif\n"
" {\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(k,input));\n"
" DOT16X16(in,weights0,out0.s0);\n"
" DOT16X16(in,weights1,out0.s1);\n"
" }\n"
" #endif\n"
" }\n"
// Tail chunk: lane 0 of the chunk is channel k8*4 (int4 image) or k4*4.
" #ifdef INPUT_CHANNEL_LEAVE\n"
" {\n"
" int k=i*loop+loop_end;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" int k8=k << 3;\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn,k)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn+1,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" \n"
" PADZEROS(k8*4,srcChannel,weights00);PADZEROS(k8*4+16,srcChannel,weights01);\n"
" PADZEROS(k8*4,srcChannel,weights10);PADZEROS(k8*4+16,srcChannel,weights11);\n"
" }\n"
" {\n"
" COMPUTE_FLOAT16 in0,in1;\n"
" in0.s0123=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k8*4));\n"
" in0.s4567=CONVERT_COMPUTE_FLOAT4(k8+1<srcChannelC4 ? vload4(0,input+(k8+1)*4) : (FLOAT4)0);\n"
" in0.s89ab=CONVERT_COMPUTE_FLOAT4(k8+2<srcChannelC4 ? vload4(0,input+(k8+2)*4) : (FLOAT4)0);\n"
" in0.scdef=CONVERT_COMPUTE_FLOAT4(k8+3<srcChannelC4 ? vload4(0,input+(k8+3)*4) : (FLOAT4)0);\n"
" in1.s0123=CONVERT_COMPUTE_FLOAT4(k8+4<srcChannelC4 ? vload4(0,input+(k8+4)*4) : (FLOAT4)0);\n"
" in1.s4567=CONVERT_COMPUTE_FLOAT4(k8+5<srcChannelC4 ? vload4(0,input+(k8+5)*4) : (FLOAT4)0);\n"
" in1.s89ab=CONVERT_COMPUTE_FLOAT4(k8+6<srcChannelC4 ? vload4(0,input+(k8+6)*4) : (FLOAT4)0);\n"
" in1.scdef=CONVERT_COMPUTE_FLOAT4(k8+7<srcChannelC4 ? vload4(0,input+(k8+7)*4) : (FLOAT4)0);\n"
" DOT16X16(in0,weights00,out0.s0);DOT16X16(in1,weights01,out0.s0);\n"
" DOT16X16(in0,weights10,out0.s1);DOT16X16(in1,weights11,out0.s1);\n"
" }\n"
" #else\n"
" int k4=k << 2;\n"
" COMPUTE_FLOAT16 weights0,weights1;\n"
" #ifdef USE_IMAGE\n"
" weights0=readWeight(weight,idn,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight,idn+1,k,ScaleOffset.s2,ScaleOffset.s3);\n"
" #else\n"
" weights0=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight+weight_offset+k*weight_oc_offset+WEIGHT_STRIDE,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" #endif\n"
" PADZEROS(k4*4,srcChannel,weights0);\n"
" PADZEROS(k4*4,srcChannel,weights1);\n"
" {\n"
" COMPUTE_FLOAT16 in;\n"
" in.s0123=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k4*4));\n"
" in.s4567=CONVERT_COMPUTE_FLOAT4(k4+1<srcChannelC4 ? vload4(0,input+(k4+1)*4) : (FLOAT4)0);\n"
" in.s89ab=CONVERT_COMPUTE_FLOAT4(k4+2<srcChannelC4 ? vload4(0,input+(k4+2)*4) : (FLOAT4)0);\n"
" in.scdef=CONVERT_COMPUTE_FLOAT4(k4+3<srcChannelC4 ? vload4(0,input+(k4+3)*4) : (FLOAT4)0);\n"
" DOT16X16(in,weights0,out0.s0);\n"
" DOT16X16(in,weights1,out0.s1);\n"
" }\n"
" #endif\n"
" }\n"
" #endif\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n"
"#endif\n"
" vstore2(CONVERT_FLOAT2(out0),0,output+out_offset);\n"
"}\n"
// gemv_conv_c1_buf: one output channel x per work-item. The result is stored
// in C4 slice x/4 at lane x%4.
// NOTE(review): input is read without a y-dependent offset (see c4 kernel).
"__kernel void gemv_conv_c1_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
"#endif\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
" __private const int bhw,\n"
" __private const int blockNum,\n"
" __private const int blockDim) {\n"
" const int x=get_global_id(0); //c\n"
" const int y=get_global_id(1); //b h w\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" int idn=x;\n"
" int idm=y;\n"
" COMPUTE_FLOAT bias0=bias[x];\n"
" COMPUTE_FLOAT out0=bias0;\n"
" \n"
" \n"
" int out_offset=((x/4)*bhw+idm)*4+(x % 4);\n"
"#ifndef USE_IMAGE\n"
" int weight_offset=x*WEIGHT_STRIDE;\n"
" int weight_oc_offset=dstChannelC4*4*WEIGHT_STRIDE;\n"
"#endif\n"
" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" const int loop_end=max(loop-1,0);\n"
"#else\n"
" const int loop_end=loop;\n"
"#endif\n"
" \n"
" for (int i=0; i<blockNum; ++i){\n"
" int kindex=i*dstChannelC4*4*2;\n"
" COMPUTE_FLOAT2 ScaleOffset=CONVERT_COMPUTE_FLOAT2(vload2(x,dequantScaleOffset+kindex));\n"
" for (int j=0; j<loop_end; ++j) {\n"
" int k=i*loop+j;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" int k32=k << 5;\n"
" COMPUTE_FLOAT16 weights00,weights01;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" }\n"
" {\n"
" COMPUTE_FLOAT16 in0=CONVERT_COMPUTE_FLOAT16(vload16(0,input+k32));\n"
" COMPUTE_FLOAT16 in1=CONVERT_COMPUTE_FLOAT16(vload16(0,input+k32+16));\n"
" DOT16X16(in0,weights00,out0);DOT16X16(in1,weights01,out0);\n"
" }\n"
" #else\n"
" COMPUTE_FLOAT16 weights;\n"
" #ifdef USE_IMAGE\n"
" weights=readWeight(weight,idn,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" #else\n"
" weights=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" #endif\n"
" {\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(k,input));\n"
" DOT16X16(in,weights,out0);\n"
" }\n"
" #endif\n"
" }\n"
// Tail chunk: lane 0 of the chunk is channel k8*4 (int4 image) or k4*4.
" #ifdef INPUT_CHANNEL_LEAVE\n"
" {\n"
" int k=i*loop+loop_end;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" int k8=k << 3;\n"
" COMPUTE_FLOAT16 weights00,weights01;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(idn,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" \n"
" PADZEROS(k8*4,srcChannel,weights00);PADZEROS(k8*4+16,srcChannel,weights01);\n"
" }\n"
" {\n"
" COMPUTE_FLOAT16 in0,in1;\n"
" in0.s0123=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k8*4));\n"
" in0.s4567=CONVERT_COMPUTE_FLOAT4(k8+1<srcChannelC4 ? vload4(0,input+(k8+1)*4) : (FLOAT4)0);\n"
" in0.s89ab=CONVERT_COMPUTE_FLOAT4(k8+2<srcChannelC4 ? vload4(0,input+(k8+2)*4) : (FLOAT4)0);\n"
" in0.scdef=CONVERT_COMPUTE_FLOAT4(k8+3<srcChannelC4 ? vload4(0,input+(k8+3)*4) : (FLOAT4)0);\n"
" in1.s0123=CONVERT_COMPUTE_FLOAT4(k8+4<srcChannelC4 ? vload4(0,input+(k8+4)*4) : (FLOAT4)0);\n"
" in1.s4567=CONVERT_COMPUTE_FLOAT4(k8+5<srcChannelC4 ? vload4(0,input+(k8+5)*4) : (FLOAT4)0);\n"
" in1.s89ab=CONVERT_COMPUTE_FLOAT4(k8+6<srcChannelC4 ? vload4(0,input+(k8+6)*4) : (FLOAT4)0);\n"
" in1.scdef=CONVERT_COMPUTE_FLOAT4(k8+7<srcChannelC4 ? vload4(0,input+(k8+7)*4) : (FLOAT4)0);\n"
" DOT16X16(in0,weights00,out0);DOT16X16(in1,weights01,out0);\n"
" }\n"
" #else\n"
" int k4=k << 2;\n"
" COMPUTE_FLOAT16 weights;\n"
" #ifdef USE_IMAGE\n"
" weights=readWeight(weight,idn,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" #else\n"
" weights=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" #endif\n"
" PADZEROS(k4*4,srcChannel,weights);\n"
" {\n"
" COMPUTE_FLOAT16 in;\n"
" in.s0123=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k4*4));\n"
" in.s4567=CONVERT_COMPUTE_FLOAT4(k4+1<srcChannelC4 ? vload4(0,input+(k4+1)*4) : (FLOAT4)0);\n"
" in.s89ab=CONVERT_COMPUTE_FLOAT4(k4+2<srcChannelC4 ? vload4(0,input+(k4+2)*4) : (FLOAT4)0);\n"
" in.scdef=CONVERT_COMPUTE_FLOAT4(k4+3<srcChannelC4 ? vload4(0,input+(k4+3)*4) : (FLOAT4)0);\n"
" DOT16X16(in,weights,out0);\n"
" }\n"
" #endif\n"
" }\n"
" #endif\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT)0,(COMPUTE_FLOAT)6);\n"
"#endif\n"
" output[out_offset]=out0;\n"
"}\n"
;
#endif
// OpenCL C source for the "raster" program, embedded as one concatenated string
// literal. It defines four kernels: buffer_set_zero, image_set_zero,
// raster_buffer_direct (2D image -> linear buffer, one element per work item) and
// raster_image (2D image -> 2D image, one 4-component texel per work item).
// The C++ comments below sit between string literals, so they are not part of the
// kernel text handed to the OpenCL compiler.
// NOTE(review): OUTPUT_TYPE, OUTPUT_TYPE_I4, INPUT_TYPE_I, INPUT_TYPE_I4, RI_DATA,
// WI_DATA and CONVERT_OUTPUT_I4 are used but not defined in this string; presumably
// the host supplies them as build options when compiling the program -- confirm.
const char* raster = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_*_DIMS prepend the logical global sizes to a kernel's parameter list;
// DEAL_NON_UNIFORM_DIM* make work items whose ids fall outside those sizes return early.
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
// Image reads use unnormalized integer coordinates, clamp (border) addressing and
// nearest filtering.
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// buffer_set_zero: zero-fills a buffer viewed as a dense 2D array whose row width is
// global_size_dim0; each work item writes one element.
"__kernel void buffer_set_zero(\n"
" GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" \n"
" output[y*global_size_dim0+x]=(OUTPUT_TYPE)(0);\n"
"}\n"
// image_set_zero: writes an all-zero texel at image coordinate (x, y); one texel per
// work item.
"__kernel void image_set_zero(\n"
" GLOBAL_SIZE_2_DIMS\n"
" __write_only image2d_t output\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" WI_DATA(output,(int2)(x,y),(OUTPUT_TYPE_I4)(0));\n"
"}\n"
// raster_buffer_direct: copies single scalar elements from a 2D image into a linear
// buffer using strided source/destination indexing.
// - Dimension 0 folds two indices: x = idx % global_size0, and a region id
//   id = idx / global_size0. The id shifts the source by combineSrcOffset and the
//   destination by combineDstOffset.
// - The source linear index is decoded into (batch, channel, height, width). Channel
//   is the fastest-varying index when INPUT_DATA_FORMAT_NHWC is defined; otherwise
//   width is fastest, then height, then channel.
// - The element lives in texel ((c/4)*src_width + w, b*src_height + h), at
//   component c%4.
"__kernel void raster_buffer_direct(\n"
" GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __private const int inputOffset,\n"
" __private const int combineSrcOffset,\n"
" __private const int inputStride0,\n"
" __private const int inputStride1,\n"
" __private const int inputStride2,\n"
" __private const int src_width,\n"
" __private const int src_height,\n"
" __private const int src_channel,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int outputOffset,\n"
" __private const int combineDstOffset,\n"
" __private const int outputStride0,\n"
" __private const int outputStride1,\n"
" __private const int outputStride2,\n"
" __private const int global_size0\n"
" ) {\n"
" const int idx=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" const int z=get_global_id(2);\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(idx,y,z);\n"
" const int x=idx % global_size0;\n"
" const int id=idx/global_size0;\n"
" \n"
" int inputIndex=inputOffset+id*combineSrcOffset+z*inputStride0+y*inputStride1+x*inputStride2;\n"
" int outputIndex=outputOffset+id*combineDstOffset+z*outputStride0+y*outputStride1+x*outputStride2;\n"
"#ifdef INPUT_DATA_FORMAT_NHWC\n"
" int in_c=inputIndex % src_channel; inputIndex /= src_channel;\n"
" int in_w=inputIndex % src_width; inputIndex /= src_width;\n"
" int in_h=inputIndex % src_height;\n"
" int in_b=inputIndex/src_height;\n"
"#else\n"
" int in_w=inputIndex % src_width; inputIndex /= src_width;\n"
" int in_h=inputIndex % src_height; inputIndex /= src_height;\n"
" int in_c=inputIndex % src_channel;\n"
" int in_b=inputIndex/src_channel;\n"
"#endif\n"
" int2 coord=(int2)((in_c/4)*src_width+in_w,in_b*src_height+in_h);\n"
" INPUT_TYPE_I4 value=RI_DATA(input,SAMPLER,coord);\n"
// View the 4-wide vector as an array so that component c%4 can be selected with a
// runtime index.
" INPUT_TYPE_I* value_ptr=(INPUT_TYPE_I*)&value;\n"
" output[outputIndex]=(OUTPUT_TYPE)value_ptr[in_c % 4];\n"
"}\n"
// raster_image: copies whole 4-component texels from one 2D image to another.
// - The strided indices are multiplied by 4 and divided by blocks of
//   (C+3)/4 * H * W * 4. Both linear indices are therefore in scalar units of a
//   layout with channels packed in groups of 4.
// - Each linear index is decoded into (n, c4, h, w) and mapped to texel
//   (c4*W + w, n*H + h).
// - The final /4 discards any remainder below 4.
// NOTE(review): inputOffset/outputOffset are added without the *4 scaling, so they are
// presumably already expressed in scalar units -- confirm against the host code.
"__kernel void raster_image(\n"
" GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __private const int inputOffset,\n"
" __private const int inputStride0,\n"
" __private const int inputStride1,\n"
" __private const int inputStride2,\n"
" __private const int inputHeight,\n"
" __private const int inputWidth,\n"
" __private const int inputChannel,\n"
" __write_only image2d_t output,\n"
" __private const int outputOffset,\n"
" __private const int outputStride0,\n"
" __private const int outputStride1,\n"
" __private const int outputStride2,\n"
" __private const int outputHeight,\n"
" __private const int outputWidth,\n"
" __private const int outputChannel\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" const int z=get_global_id(2);\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" int inputIndex=inputOffset+(z*inputStride0+y*inputStride1+x*inputStride2)*4;\n"
" int outputIndex=outputOffset+(z*outputStride0+y*outputStride1+x*outputStride2)*4;\n"
// Decode the source index into batch n, channel block c4, row h and column w.
" int inp_idx_n=inputIndex/((inputChannel+3)/4*inputHeight*inputWidth*4);\n"
" int inputIndex_left=inputIndex % ((inputChannel+3)/4*inputHeight*inputWidth*4);\n"
" int inp_idx_c4=inputIndex_left/(inputHeight*inputWidth*4);\n"
" inputIndex_left=inputIndex_left % (inputHeight*inputWidth*4);\n"
" int inp_idx_h=inputIndex_left/(inputWidth*4);\n"
" inputIndex_left=inputIndex_left % (inputWidth*4);\n"
" int inp_idx_w=inputIndex_left/4;\n"
" \n"
// Same decoding for the destination index.
" int out_idx_n=outputIndex/((outputChannel+3)/4*outputHeight*outputWidth*4);\n"
" int outputIndex_left=outputIndex % ((outputChannel+3)/4*outputHeight*outputWidth*4);\n"
" int out_idx_c4=outputIndex_left/(outputHeight*outputWidth*4);\n"
" outputIndex_left=outputIndex_left % (outputHeight*outputWidth*4);\n"
" int out_idx_h=outputIndex_left/(outputWidth*4);\n"
" outputIndex_left=outputIndex_left % (outputWidth*4);\n"
" int out_idx_w=outputIndex_left/4;\n"
" \n"
// Image x spans channel blocks side by side (c4*W + w); image y stacks batches of rows
// (n*H + h).
" int inp_idx0=inp_idx_c4*inputWidth+inp_idx_w;\n"
" int inp_idx1=inp_idx_n*inputHeight+inp_idx_h;\n"
" int out_idx0=out_idx_c4*outputWidth+out_idx_w;\n"
" int out_idx1=out_idx_n*outputHeight+out_idx_h;\n"
" INPUT_TYPE_I4 out=RI_DATA(input,SAMPLER,(int2)(inp_idx0,inp_idx1));\n"
" WI_DATA(output,(int2)(out_idx0,out_idx1),CONVERT_OUTPUT_I4(out));\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
const char* conv_2d_c1_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#define GROUP_READ(ptr,offset) as_half(intel_sub_group_block_read_us((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_READ2(ptr,offset) as_half2(intel_sub_group_block_read_us2((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_READ4(ptr,offset) as_half4(intel_sub_group_block_read_us4((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_READ8(ptr,offset) as_half8(intel_sub_group_block_read_us8((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_WRITE(ptr,offset,val) intel_sub_group_block_write_us((const __global ushort*)(ptr)+(offset),as_ushort(val))\n"
"#define GROUP_WRITE2(ptr,offset,val) intel_sub_group_block_write_us2((const __global ushort*)(ptr)+(offset),as_ushort2(val))\n"
"#define GROUP_WRITE4(ptr,offset,val) intel_sub_group_block_write_us4((const __global ushort*)(ptr)+(offset),as_ushort4(val))\n"
"#define GROUP_WRITE8(ptr,offset,val) intel_sub_group_block_write_us8((const __global ushort*)(ptr)+(offset),as_ushort8(val))\n"
"#define GROUP_SHUFFLE(data,id) as_half(intel_sub_group_shuffle(as_ushort(data),id))\n"
"#define GROUP_SHUFFLE2(data,id) as_half2(intel_sub_group_shuffle(as_ushort2(data),id))\n"
"#define GROUP_SHUFFLE4(data,id) as_half4(intel_sub_group_shuffle(as_ushort4(data),id))\n"
"#define GROUP_SHUFFLE8(data,id) as_half8(intel_sub_group_shuffle(as_ushort8(data),id))\n"
"#else\n"
"#define GROUP_READ(ptr,offset) as_float(intel_sub_group_block_read((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_READ2(ptr,offset) as_float2(intel_sub_group_block_read2((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_READ4(ptr,offset) as_float4(intel_sub_group_block_read4((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_READ8(ptr,offset) as_float8(intel_sub_group_block_read8((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_WRITE(ptr,offset,val) intel_sub_group_block_write((const __global uint*)(ptr)+(offset),as_uint(val))\n"
"#define GROUP_WRITE2(ptr,offset,val) intel_sub_group_block_write2((const __global uint*)(ptr)+(offset),as_uint2(val))\n"
"#define GROUP_WRITE4(ptr,offset,val) intel_sub_group_block_write4((const __global uint*)(ptr)+(offset),as_uint4(val))\n"
"#define GROUP_WRITE8(ptr,offset,val) intel_sub_group_block_write8((const __global uint*)(ptr)+(offset),as_uint8(val))\n"
"#define GROUP_SHUFFLE(data,id) intel_sub_group_shuffle(data,id)\n"
"#define GROUP_SHUFFLE2(data,id) intel_sub_group_shuffle(data,id)\n"
"#define GROUP_SHUFFLE4(data,id) intel_sub_group_shuffle(data,id)\n"
"#define GROUP_SHUFFLE8(data,id) intel_sub_group_shuffle(data,id)\n"
"#endif\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c1_c4_b2(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
")\n"
"{\n"
" const int f_block=get_group_id(1);\n"
" const int lid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 1;\n"
" const int y=(xy/x_blocks);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=1;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*INPUT_CHANNEL;\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" input_x*input_x_pitch;\n"
" const uint output_pack=(output_channel+3)/4;\n"
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*output_width;\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*batch;\n"
" \n"
" \n"
" const uint output_offset=b*output_fs_pitch +\n"
" f_block*4*output_b_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=256;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=f_block*filter_os_pitch;\n"
" uint bias_offset=f_block*16;\n"
" COMPUTE_FLOAT2 dst=(COMPUTE_FLOAT2)(GROUP_READ(biases,bias_offset));\n"
" \n"
" FLOAT line_cache[INPUT_CHANNEL*INPUT_BLOCK_SIZE];\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(INPUT_BLOCK_SIZE)))\n"
" for (int i=0; i<INPUT_BLOCK_SIZE; i++)\n"
" {\n"
" const int in_elem=i*16+lid;\n"
" const int xb=in_elem % INPUT_LINE_SIZE;\n"
" const int yb=in_elem/INPUT_LINE_SIZE;\n"
" if (input_y+yb >= 0 && input_y+yb<input_height &&\n"
" input_x+xb >= 0 && input_x+xb<input_width)\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=input[input_offset +\n"
" ic*input_f_pitch +\n"
" xb*input_x_pitch +\n"
" yb*input_y_pitch];\n"
" else\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=0;\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++)\n"
" {\n"
" uint offset=filter_offset+kh*filter_y_pitch+kw*filter_x_pitch;\n"
" \n"
" COMPUTE_FLOAT wei[INPUT_CHANNEL];\n"
" __attribute__((opencl_unroll_hint(INPUT_CHANNEL)))\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" wei[ic]=GROUP_READ(weights,offset+ic*filter_isv_pitch);\n"
" \n"
" __attribute__((opencl_unroll_hint(2)))\n"
" for (int i=0; i<2; i++)\n"
" {\n"
" const uint buf_offset=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE)/16;\n"
" const uint buf_group=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE) % 16;\n"
" \n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++) {\n"
" COMPUTE_FLOAT src=GROUP_SHUFFLE(line_cache[ic*INPUT_BLOCK_SIZE+buf_offset],buf_group);\n"
" dst[i]=mad(wei[ic],src,dst[i]);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n"
"#endif\n"
" const uint lid_x=lid % 4;\n"
" const uint lid_y=lid/4;\n"
" if ((f_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<2 && (x+i)<output_width; i++) {\n"
" if ((f_block*16+lid_y*4<output_pack*4))\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" for (int i=0; i<2 && (x+i)<output_width; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c1_c4_b4(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
")\n"
"{\n"
" const int f_block=get_group_id(1);\n"
" const int lid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 2;\n"
" const int y=(xy/x_blocks);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=1;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*INPUT_CHANNEL;\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" input_x*input_x_pitch;\n"
" const uint output_pack=(output_channel+3)/4;\n"
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*output_width;\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*batch;\n"
" \n"
" \n"
" const uint output_offset=b*output_fs_pitch +\n"
" f_block*4*output_b_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=256;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=f_block*filter_os_pitch;\n"
" uint bias_offset=f_block*16;\n"
" COMPUTE_FLOAT4 dst=(COMPUTE_FLOAT4)(GROUP_READ(biases,bias_offset));\n"
" \n"
" FLOAT line_cache[INPUT_CHANNEL*INPUT_BLOCK_SIZE];\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(INPUT_BLOCK_SIZE)))\n"
" for (int i=0; i<INPUT_BLOCK_SIZE; i++)\n"
" {\n"
" const int in_elem=i*16+lid;\n"
" const int xb=in_elem % INPUT_LINE_SIZE;\n"
" const int yb=in_elem/INPUT_LINE_SIZE;\n"
" if (input_y+yb >= 0 && input_y+yb<input_height &&\n"
" input_x+xb >= 0 && input_x+xb<input_width)\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=input[input_offset +\n"
" ic*input_f_pitch +\n"
" xb*input_x_pitch +\n"
" yb*input_y_pitch];\n"
" else\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=0;\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++)\n"
" {\n"
" uint offset=filter_offset+kh*filter_y_pitch+kw*filter_x_pitch;\n"
" \n"
" COMPUTE_FLOAT wei[INPUT_CHANNEL];\n"
" __attribute__((opencl_unroll_hint(INPUT_CHANNEL)))\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" wei[ic]=GROUP_READ(weights,offset+ic*filter_isv_pitch);\n"
" \n"
" __attribute__((opencl_unroll_hint(4)))\n"
" for (int i=0; i<4; i++)\n"
" {\n"
" const uint buf_offset=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE)/16;\n"
" const uint buf_group=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE) % 16;\n"
" \n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++) {\n"
" COMPUTE_FLOAT src=GROUP_SHUFFLE(line_cache[ic*INPUT_BLOCK_SIZE+buf_offset],buf_group);\n"
" dst[i]=mad(wei[ic],src,dst[i]);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const uint lid_x=lid % 4;\n"
" const uint lid_y=lid/4;\n"
" if ((f_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<4 && (x+i)<output_width; i++) {\n"
" if ((f_block*16+lid_y*4<output_pack*4))\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" for (int i=0; i<4 && (x+i)<output_width; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c1_c4_b8(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
")\n"
"{\n"
" const int f_block=get_group_id(1);\n"
" const int lid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 3;\n"
" const int y=(xy/x_blocks);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=1;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*INPUT_CHANNEL;\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" input_x*input_x_pitch;\n"
" const uint output_pack=(output_channel+3)/4;\n"
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*output_width;\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*batch;\n"
" \n"
" \n"
" const uint output_offset=b*output_fs_pitch +\n"
" f_block*4*output_b_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=256;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=f_block*filter_os_pitch;\n"
" uint bias_offset=f_block*16;\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(GROUP_READ(biases,bias_offset));\n"
" \n"
" FLOAT line_cache[INPUT_CHANNEL*INPUT_BLOCK_SIZE];\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(INPUT_BLOCK_SIZE)))\n"
" for (int i=0; i<INPUT_BLOCK_SIZE; i++)\n"
" {\n"
" const int in_elem=i*16+lid;\n"
" const int xb=in_elem % INPUT_LINE_SIZE;\n"
" const int yb=in_elem/INPUT_LINE_SIZE;\n"
" if (input_y+yb >= 0 && input_y+yb<input_height &&\n"
" input_x+xb >= 0 && input_x+xb<input_width)\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=input[input_offset +\n"
" ic*input_f_pitch +\n"
" xb*input_x_pitch +\n"
" yb*input_y_pitch];\n"
" else\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=0;\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++)\n"
" {\n"
" uint offset=filter_offset+kh*filter_y_pitch+kw*filter_x_pitch;\n"
" \n"
" COMPUTE_FLOAT wei[INPUT_CHANNEL];\n"
" __attribute__((opencl_unroll_hint(INPUT_CHANNEL)))\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" wei[ic]=GROUP_READ(weights,offset+ic*filter_isv_pitch);\n"
" \n"
" __attribute__((opencl_unroll_hint(8)))\n"
" for (int i=0; i<8; i++)\n"
" {\n"
" const uint buf_offset=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE)/16;\n"
" const uint buf_group=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE) % 16;\n"
" \n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++) {\n"
" COMPUTE_FLOAT src=GROUP_SHUFFLE(line_cache[ic*INPUT_BLOCK_SIZE+buf_offset],buf_group);\n"
" dst[i]=mad(wei[ic],src,dst[i]);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
" const uint lid_x=lid % 4;\n"
" const uint lid_y=lid/4;\n"
" if ((f_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<8 && (x+i)<output_width; i++) {\n"
" if ((f_block*16+lid_y*4<output_pack*4))\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" for (int i=0; i<8 && (x+i)<output_width; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c1_c16_b2(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
")\n"
"{\n"
" const int f_block=get_group_id(1);\n"
" const int lid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 1;\n"
" const int y=(xy/x_blocks);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=1;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*INPUT_CHANNEL;\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" input_x*input_x_pitch;\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(output_pad_left+output_width+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*((output_channel+15)/16);\n"
" \n"
" \n"
" const uint output_offset=b*output_b_pitch +\n"
" f_block*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=256;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=f_block*filter_os_pitch;\n"
" uint bias_offset=f_block*16;\n"
" COMPUTE_FLOAT2 dst=(COMPUTE_FLOAT2)(GROUP_READ(biases,bias_offset));\n"
" \n"
" FLOAT line_cache[INPUT_CHANNEL*INPUT_BLOCK_SIZE];\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(INPUT_BLOCK_SIZE)))\n"
" for (int i=0; i<INPUT_BLOCK_SIZE; i++)\n"
" {\n"
" const int in_elem=i*16+lid;\n"
" const int xb=in_elem % INPUT_LINE_SIZE;\n"
" const int yb=in_elem/INPUT_LINE_SIZE;\n"
" if (input_y+yb >= 0 && input_y+yb<input_height &&\n"
" input_x+xb >= 0 && input_x+xb<input_width)\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=input[input_offset +\n"
" ic*input_f_pitch +\n"
" xb*input_x_pitch +\n"
" yb*input_y_pitch];\n"
" else\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=0;\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++)\n"
" {\n"
" uint offset=filter_offset+kh*filter_y_pitch+kw*filter_x_pitch;\n"
" \n"
" COMPUTE_FLOAT wei[INPUT_CHANNEL];\n"
" __attribute__((opencl_unroll_hint(INPUT_CHANNEL)))\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" wei[ic]=GROUP_READ(weights,offset+ic*filter_isv_pitch);\n"
" \n"
" __attribute__((opencl_unroll_hint(2)))\n"
" for (int i=0; i<2; i++)\n"
" {\n"
" const uint buf_offset=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE)/16;\n"
" const uint buf_group=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE) % 16;\n"
" \n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++) {\n"
" COMPUTE_FLOAT src=GROUP_SHUFFLE(line_cache[ic*INPUT_BLOCK_SIZE+buf_offset],buf_group);\n"
" dst[i]=mad(wei[ic],src,dst[i]);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n"
"#endif\n"
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+f_block*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+lid]=0;\n"
" }\n"
" pad_offset += (output_width+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+lid]=0;\n"
" }\n"
" }\n"
" if ((f_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<2; i++) {\n"
" if ((f_block*16+lid<output_channel) && (x+i)<output_width)\n"
" output[output_offset+i*output_x_pitch+lid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" if (x+2 <= output_width || output_width % 2 == 0) {\n"
" GROUP_WRITE2(output,output_offset,CONVERT_FLOAT2(dst));\n"
" }else{\n"
" for (int i=0; i<output_width % 2; i++) {\n"
" output[output_offset+i*output_x_pitch+lid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c1_c16_b4(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
")\n"
"{\n"
" const int f_block=get_group_id(1);\n"
" const int lid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 2;\n"
" const int y=(xy/x_blocks);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=1;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*INPUT_CHANNEL;\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" input_x*input_x_pitch;\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(output_pad_left+output_width+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*((output_channel+15)/16);\n"
" \n"
" \n"
" const uint output_offset=b*output_b_pitch +\n"
" f_block*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=256;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=f_block*filter_os_pitch;\n"
" uint bias_offset=f_block*16;\n"
" COMPUTE_FLOAT4 dst=(COMPUTE_FLOAT4)(GROUP_READ(biases,bias_offset));\n"
" \n"
" FLOAT line_cache[INPUT_CHANNEL*INPUT_BLOCK_SIZE];\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(INPUT_BLOCK_SIZE)))\n"
" for (int i=0; i<INPUT_BLOCK_SIZE; i++)\n"
" {\n"
" const int in_elem=i*16+lid;\n"
" const int xb=in_elem % INPUT_LINE_SIZE;\n"
" const int yb=in_elem/INPUT_LINE_SIZE;\n"
" if (input_y+yb >= 0 && input_y+yb<input_height &&\n"
" input_x+xb >= 0 && input_x+xb<input_width)\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=input[input_offset +\n"
" ic*input_f_pitch +\n"
" xb*input_x_pitch +\n"
" yb*input_y_pitch];\n"
" else\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=0;\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++)\n"
" {\n"
" uint offset=filter_offset+kh*filter_y_pitch+kw*filter_x_pitch;\n"
" \n"
" COMPUTE_FLOAT wei[INPUT_CHANNEL];\n"
" __attribute__((opencl_unroll_hint(INPUT_CHANNEL)))\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" wei[ic]=GROUP_READ(weights,offset+ic*filter_isv_pitch);\n"
" \n"
" __attribute__((opencl_unroll_hint(4)))\n"
" for (int i=0; i<4; i++)\n"
" {\n"
" const uint buf_offset=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE)/16;\n"
" const uint buf_group=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE) % 16;\n"
" \n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++) {\n"
" COMPUTE_FLOAT src=GROUP_SHUFFLE(line_cache[ic*INPUT_BLOCK_SIZE+buf_offset],buf_group);\n"
" dst[i]=mad(wei[ic],src,dst[i]);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+f_block*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+lid]=0;\n"
" }\n"
" pad_offset += (output_width+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+lid]=0;\n"
" }\n"
" }\n"
" if ((f_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<4; i++) {\n"
" if ((f_block*16+lid<output_channel) && (x+i)<output_width)\n"
" output[output_offset+i*output_x_pitch+lid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" if (x+4 <= output_width || output_width % 4 == 0) {\n"
" GROUP_WRITE4(output,output_offset,CONVERT_FLOAT4(dst));\n"
" }else{\n"
" for (int i=0; i<output_width % 4; i++) {\n"
" output[output_offset+i*output_x_pitch+lid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c1_c16_b8(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
")\n"
"{\n"
" const int f_block=get_group_id(1);\n"
" const int lid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 3;\n"
" const int y=(xy/x_blocks);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=1;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*INPUT_CHANNEL;\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" input_x*input_x_pitch;\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(output_pad_left+output_width+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*((output_channel+15)/16);\n"
" \n"
" \n"
" const uint output_offset=b*output_b_pitch +\n"
" f_block*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=256;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=f_block*filter_os_pitch;\n"
" uint bias_offset=f_block*16;\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(GROUP_READ(biases,bias_offset));\n"
" \n"
" FLOAT line_cache[INPUT_CHANNEL*INPUT_BLOCK_SIZE];\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(INPUT_BLOCK_SIZE)))\n"
" for (int i=0; i<INPUT_BLOCK_SIZE; i++)\n"
" {\n"
" const int in_elem=i*16+lid;\n"
" const int xb=in_elem % INPUT_LINE_SIZE;\n"
" const int yb=in_elem/INPUT_LINE_SIZE;\n"
" if (input_y+yb >= 0 && input_y+yb<input_height &&\n"
" input_x+xb >= 0 && input_x+xb<input_width)\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=input[input_offset +\n"
" ic*input_f_pitch +\n"
" xb*input_x_pitch +\n"
" yb*input_y_pitch];\n"
" else\n"
" line_cache[ic*INPUT_BLOCK_SIZE+i]=0;\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++)\n"
" {\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++)\n"
" {\n"
" uint offset=filter_offset+kh*filter_y_pitch+kw*filter_x_pitch;\n"
" \n"
" COMPUTE_FLOAT wei[INPUT_CHANNEL];\n"
" __attribute__((opencl_unroll_hint(INPUT_CHANNEL)))\n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++)\n"
" wei[ic]=GROUP_READ(weights,offset+ic*filter_isv_pitch);\n"
" \n"
" __attribute__((opencl_unroll_hint(8)))\n"
" for (int i=0; i<8; i++)\n"
" {\n"
" const uint buf_offset=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE)/16;\n"
" const uint buf_group=(kw*DILATION_WIDTH+STRIDE_WIDTH*i+(kh*DILATION_HEIGHT)*INPUT_LINE_SIZE) % 16;\n"
" \n"
" for (int ic=0; ic<INPUT_CHANNEL; ic++) {\n"
" COMPUTE_FLOAT src=GROUP_SHUFFLE(line_cache[ic*INPUT_BLOCK_SIZE+buf_offset],buf_group);\n"
" dst[i]=mad(wei[ic],src,dst[i]);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+f_block*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+lid]=0;\n"
" }\n"
" pad_offset += (output_width+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+lid]=0;\n"
" }\n"
" }\n"
" if ((f_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<8; i++) {\n"
" if ((f_block*16+lid<output_channel) && (x+i)<output_width)\n"
" output[output_offset+i*output_x_pitch+lid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" if (x+8 <= output_width || output_width % 8 == 0) {\n"
" GROUP_WRITE8(output,output_offset,CONVERT_FLOAT8(dst));\n"
" }else{\n"
" for (int i=0; i<output_width % 8; i++) {\n"
" output[output_offset+i*output_x_pitch+lid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" }\n"
"}\n"
;
#endif
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for two tiled GEMM kernels computing C = A * B (+ bias):
//   - matmul_local_buf:        caches one CPWK-deep tile of A and B in local
//                              memory per step.
//   - matmul_local_double_buf: alternates between two local tile buffers
//                              (the *R and *C arrays) and consumes two K-tiles
//                              per loop iteration.
//
// Build-time defines expected from the host (the defaults shown in the
// leading comment of the source are commented out):
//   OPWM / OPWN  outputs per work-group in M / N
//   OPTM / OPTN  outputs per work-item in M / N
//   CPWK         K-slice cached per step. The inner-product code reads
//                Alocal[0..7] / Blocal[..][0..7] explicitly, so the whole
//                tile is only consumed when CPWK == 8.
//   LPTA / LPTB are integer divisions, so CPWK*OPWM and CPWK*OPWN are
//   expected to be multiples of TPWM*TPWN.
//   Optional: TRANSPOSE_A (A stored [K,M] instead of [M,K]),
//             TRANSPOSE_B (B stored [N,K] instead of [K,N]),
//             BIAS (adds bias[col] to every output column).
//
// Launch shape assumed by the indexing: local size (TPWM, TPWN) with
// TPWM = OPWM/OPTM and TPWN = OPWN/OPTN. Group id 0 selects the M tile and
// group id 1 selects the N tile.
//
// Neither kernel bounds-checks M or N. Presumably the host pads them to
// multiples of OPWM / OPWN (TODO confirm against the launcher). Any K
// remainder beyond a multiple of CPWK is not accumulated.
//
// NOTE(review): under USE_LOW_BIT_WEIGHT_INT8/INT4, B changes type and the
// dequantScale/dequantOffset parameters are added, but neither kernel reads
// those parameters. B's raw char/uchar values are copied to local memory
// without dequantization.
const char* matmul_local_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"/*\n"
" "" #define OPWM 64 // The outputsize-per-workgroup in dimension M\n"
" #define OPWN 128 // The outputsize-per-workgroup in dimension N\n"
" #define CPWK 8 // The cachesize-per-workgroup in dimension K\n"
" #define OPTM 4 // The outputsize-per-thread in dimension M\n"
" #define OPTN 8 // The outputsize-per-thread in dimension N\n"
" */\n"
"#define TPWM (OPWM/OPTM) // The threadsize-per-workgroup in dimension M\n"
"#define TPWN (OPWN/OPTN) // The threadsize-per-workgroup in dimension N\n"
"#define LPTA ((CPWK*OPWM)/(TPWM*TPWN)) // Loads-num-per-thread for A\n"
"#define LPTB ((CPWK*OPWN)/(TPWM*TPWN)) // Loads-num-per-thread for B\n"
"// vetorize+pragma unroll\n"
"__kernel void matmul_local_buf(const int M,const int N,const int K,\n"
" __global const FLOAT* A,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char* B,\n"
" __global const float* dequantScale,\n"
" __global const float* dequantOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar* B,\n"
" __global const float* dequantScale,\n"
" __global const float* dequantOffset,\n"
"#else\n"
" __global const FLOAT* B,\n"
"#endif\n"
"#ifdef BIAS\n"
" __global const FLOAT* bias,\n"
"#endif\n"
" __global FLOAT* C) {\n"
" // Local thread id\n"
" const int lidm=get_local_id(0); // Local row ID\n"
" const int lidn=get_local_id(1); // Local col ID\n"
" // group id\n"
" const int offsetM=get_group_id(0)*OPWM; // Work-group offset M\n"
" const int offsetN=get_group_id(1)*OPWN; // Work-group offset N\n"
" // Local memory for work-group cache of A and B\n"
" __local FLOAT Alocal[CPWK][OPWM];\n"
" __local FLOAT Blocal[OPWN][CPWK+2];\n"
" // Allocate register space\n"
" COMPUTE_FLOAT sum[OPTM][OPTN];\n"
" // Initialise the accumulation registers\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn]=0.0f;\n"
" }\n"
" }\n"
" \n"
" // Loop over all tiles\n"
" const int numLoops=K/CPWK;\n"
" int lid=lidn*TPWM+lidm;\n"
" for (int t=0; t<numLoops; t++) {\n"
" // Load one work-group of A and B into local memory\n"
" for (int la=0; la<LPTA; la++) {\n"
" int id=la*TPWN*TPWM+lid;\n"
" int row=id % OPWM;\n"
" int col=id/OPWM;\n"
" int tiledIndex=CPWK*t+col;\n"
" #ifdef TRANSPOSE_A\n"
" // [K,M]\n"
" Alocal[col][row]=A[tiledIndex*M+(offsetM+row)];\n"
" #else\n"
" // [M,K]\n"
" Alocal[col][row]=A[(offsetM+row)*K+tiledIndex];\n"
" #endif\n"
" }\n"
" for (int la=0; la<LPTB; la++) {\n"
" int id=la*TPWN*TPWM+lid;\n"
" int row=id % OPWN;\n"
" int col=id/OPWN;\n"
" int tiledIndex=CPWK*t+col;\n"
" #ifdef TRANSPOSE_B\n"
" // [N,K]\n"
" Blocal[row][col]=B[(offsetN+row)*K+tiledIndex];\n"
" #else\n"
" // [K,N]\n"
" Blocal[row][col]=B[tiledIndex*N+offsetN+row];\n"
" #endif\n"
" }\n"
" \n"
" // Synchronise to make sure the tile is loaded\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" // Loop over the values of a single tile\n"
" \n"
" // Perform the computation\n"
" FLOAT4 A_k0,B_k0[OPTN];\n"
" {\n"
" int row=lidm;\n"
" int col=lidn;\n"
" \n"
" A_k0.s0=Alocal[0][row];\n"
" A_k0.s1=Alocal[1][row];\n"
" A_k0.s2=Alocal[2][row];\n"
" A_k0.s3=Alocal[3][row];\n"
" \n"
" #pragma unroll\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" B_k0[wn].s0=Blocal[col][0];\n"
" B_k0[wn].s1=Blocal[col][1];\n"
" B_k0[wn].s2=Blocal[col][2];\n"
" B_k0[wn].s3=Blocal[col][3];\n"
" sum[0][wn] += dot(A_k0,B_k0[wn]);\n"
" col += TPWN;\n"
" }\n"
" \n"
" #pragma unroll\n"
" for(int wm=1; wm<OPTM; wm++) {\n"
" row += TPWM;\n"
" A_k0.s0=Alocal[0][row];\n"
" A_k0.s1=Alocal[1][row];\n"
" A_k0.s2=Alocal[2][row];\n"
" A_k0.s3=Alocal[3][row];\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn] += dot(A_k0,B_k0[wn]);\n"
" }\n"
" }\n"
" }\n"
" {\n"
" int col=lidn;\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" B_k0[wn].s0=Blocal[col][4];\n"
" B_k0[wn].s1=Blocal[col][5];\n"
" B_k0[wn].s2=Blocal[col][6];\n"
" B_k0[wn].s3=Blocal[col][7];\n"
" col += TPWN;\n"
" }\n"
" int row=lidm;\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" A_k0.s0=Alocal[4][row];\n"
" A_k0.s1=Alocal[5][row];\n"
" A_k0.s2=Alocal[6][row];\n"
" A_k0.s3=Alocal[7][row];\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn] += dot(A_k0,B_k0[wn]);\n"
" }\n"
" row += TPWM;\n"
" }\n"
" }\n"
" // Synchronise before loading the next tile\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" // Store the final results in C\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" int globalRow=offsetM+lidm+wm*TPWM;\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" int globalCol=offsetN+lidn+wn*TPWN;\n"
" #ifdef BIAS\n"
" sum[wm][wn] += bias[globalCol];\n"
" #endif\n"
" C[globalRow*N+globalCol]=sum[wm][wn];\n"
" }\n"
" }\n"
"}\n"
// Double-buffered variant. All index arithmetic uses 32-bit int (as in
// matmul_local_buf above). The previous 16-bit ushort offsets wrapped once
// a matrix dimension or tile offset reached 65536.
"// double buffer\n"
"__kernel void matmul_local_double_buf(const int M,const int N,const int K,\n"
" __global const FLOAT* A,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char* B,\n"
" __global const float* dequantScale,\n"
" __global const float* dequantOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar* B,\n"
" __global const float* dequantScale,\n"
" __global const float* dequantOffset,\n"
"#else\n"
" __global const FLOAT* B,\n"
"#endif\n"
"#ifdef BIAS\n"
" __global const FLOAT* bias,\n"
"#endif\n"
" __global FLOAT* C) {\n"
" // Local thread id\n"
" const int lidm=get_local_id(0); // Local row ID\n"
" const int lidn=get_local_id(1); // Local col ID\n"
" // group id\n"
" const int offsetM=get_group_id(0)*OPWM; // Work-group offset M\n"
" const int offsetN=get_group_id(1)*OPWN; // Work-group offset N\n"
" // Local memory for work-group cache of A and B\n"
" __local FLOAT AlocalR[CPWK][OPWM];\n"
" __local FLOAT BlocalR[OPWN][CPWK+2];\n"
" __local FLOAT AlocalC[CPWK][OPWM];\n"
" __local FLOAT BlocalC[OPWN][CPWK+2];\n"
" \n"
" // Allocate register space\n"
" COMPUTE_FLOAT sum[OPTM][OPTN];\n"
" // Initialise the accumulation registers\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn]=0.0f;\n"
" }\n"
" }\n"
" \n"
" // Loop over all tiles\n"
" const int numLoops=K/CPWK;\n"
" int lid=lidn*TPWM+lidm;\n"
" for (int t=0; t<numLoops; t++) {\n"
" // Load one work-group of A and B into local memory\n"
" for (int la=0; la<LPTA; la++) {\n"
" int id=la*TPWN*TPWM+lid;\n"
" int row=id % OPWM;\n"
" int col=id/OPWM;\n"
" int tiledIndex=CPWK*t+col;\n"
" #ifdef TRANSPOSE_A\n"
" // [K,M]\n"
" AlocalR[col][row]=A[tiledIndex*M+(offsetM+row)];\n"
" #else\n"
" // [M,K]\n"
" AlocalR[col][row]=A[(offsetM+row)*K+tiledIndex];\n"
" #endif\n"
" }\n"
" for (int la=0; la<LPTB; la++) {\n"
" int id=la*TPWN*TPWM+lid;\n"
" int row=id % OPWN;\n"
" int col=id/OPWN;\n"
" int tiledIndex=CPWK*t+col;\n"
" #ifdef TRANSPOSE_B\n"
" // [N,K]\n"
" BlocalR[row][col]=B[(offsetN+row)*K+tiledIndex];\n"
" #else\n"
" // [K,N]\n"
" BlocalR[row][col]=B[tiledIndex*N+offsetN+row];\n"
" #endif\n"
" }\n"
" \n"
" // Synchronise to make sure the tile is loaded\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" \n"
" // Loop over the values of a single tile\n"
" \n"
" // Perform the computation\n"
" FLOAT4 A_k0,B_k0[OPTN];\n"
" {\n"
" int row=lidm;\n"
" int col=lidn;\n"
" \n"
" A_k0.s0=AlocalR[0][row];\n"
" A_k0.s1=AlocalR[1][row];\n"
" A_k0.s2=AlocalR[2][row];\n"
" A_k0.s3=AlocalR[3][row];\n"
" \n"
" #pragma unroll\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" B_k0[wn].s0=BlocalR[col][0];\n"
" B_k0[wn].s1=BlocalR[col][1];\n"
" B_k0[wn].s2=BlocalR[col][2];\n"
" B_k0[wn].s3=BlocalR[col][3];\n"
" sum[0][wn] += dot(A_k0,B_k0[wn]);\n"
" col += TPWN;\n"
" }\n"
" \n"
" #pragma unroll\n"
" for(int wm=1; wm<OPTM; wm++) {\n"
" row += TPWM;\n"
" A_k0.s0=AlocalR[0][row];\n"
" A_k0.s1=AlocalR[1][row];\n"
" A_k0.s2=AlocalR[2][row];\n"
" A_k0.s3=AlocalR[3][row];\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn] += dot(A_k0,B_k0[wn]);\n"
" }\n"
" }\n"
" }\n"
" {\n"
" int col=lidn;\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" B_k0[wn].s0=BlocalR[col][4];\n"
" B_k0[wn].s1=BlocalR[col][5];\n"
" B_k0[wn].s2=BlocalR[col][6];\n"
" B_k0[wn].s3=BlocalR[col][7];\n"
" col += TPWN;\n"
" }\n"
" int row=lidm;\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" A_k0.s0=AlocalR[4][row];\n"
" A_k0.s1=AlocalR[5][row];\n"
" A_k0.s2=AlocalR[6][row];\n"
" A_k0.s3=AlocalR[7][row];\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn] += dot(A_k0,B_k0[wn]);\n"
" }\n"
" row += TPWM;\n"
" }\n"
" }\n"
" \n"
" t++;\n"
// When K/CPWK is odd, the final iteration has no second tile. Skip the
// second half-step instead of reading A/B past K. Both t and numLoops have
// the same value in every work-item, so the barrier inside this branch is
// still reached by the whole work-group or by none of it.
" if (t<numLoops) {\n"
" // Loop over the values of a single tile\n"
" // Load one work-group of A and B into local memory\n"
" for (int la=0; la<LPTA; la++) {\n"
" int id=la*TPWN*TPWM+lid;\n"
" int row=id % OPWM;\n"
" int col=id/OPWM;\n"
" int tiledIndex=CPWK*t+col;\n"
" #ifdef TRANSPOSE_A\n"
" // [K,M]\n"
" AlocalC[col][row]=A[tiledIndex*M+(offsetM+row)];\n"
" #else\n"
" // [M,K]\n"
" AlocalC[col][row]=A[(offsetM+row)*K+tiledIndex];\n"
" #endif\n"
" }\n"
" for (int la=0; la<LPTB; la++) {\n"
" int id=la*TPWN*TPWM+lid;\n"
" int row=id % OPWN;\n"
" int col=id/OPWN;\n"
" int tiledIndex=CPWK*t+col;\n"
" #ifdef TRANSPOSE_B\n"
" // [N,K]\n"
" BlocalC[row][col]=B[(offsetN+row)*K+tiledIndex];\n"
" #else\n"
" // [K,N]\n"
" BlocalC[row][col]=B[tiledIndex*N+offsetN+row];\n"
" #endif\n"
" }\n"
" // Synchronise to make sure the tile is loaded\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" \n"
" // Perform the computation\n"
" {\n"
" int row=lidm;\n"
" int col=lidn;\n"
" \n"
" A_k0.s0=AlocalC[0][row];\n"
" A_k0.s1=AlocalC[1][row];\n"
" A_k0.s2=AlocalC[2][row];\n"
" A_k0.s3=AlocalC[3][row];\n"
" \n"
" #pragma unroll\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" B_k0[wn].s0=BlocalC[col][0];\n"
" B_k0[wn].s1=BlocalC[col][1];\n"
" B_k0[wn].s2=BlocalC[col][2];\n"
" B_k0[wn].s3=BlocalC[col][3];\n"
" sum[0][wn] += dot(A_k0,B_k0[wn]);\n"
" col += TPWN;\n"
" }\n"
" \n"
" #pragma unroll\n"
" for(int wm=1; wm<OPTM; wm++) {\n"
" row += TPWM;\n"
" A_k0.s0=AlocalC[0][row];\n"
" A_k0.s1=AlocalC[1][row];\n"
" A_k0.s2=AlocalC[2][row];\n"
" A_k0.s3=AlocalC[3][row];\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn] += dot(A_k0,B_k0[wn]);\n"
" }\n"
" }\n"
" }\n"
" {\n"
" int col=lidn;\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" B_k0[wn].s0=BlocalC[col][4];\n"
" B_k0[wn].s1=BlocalC[col][5];\n"
" B_k0[wn].s2=BlocalC[col][6];\n"
" B_k0[wn].s3=BlocalC[col][7];\n"
" col += TPWN;\n"
" }\n"
" int row=lidm;\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" A_k0.s0=AlocalC[4][row];\n"
" A_k0.s1=AlocalC[5][row];\n"
" A_k0.s2=AlocalC[6][row];\n"
" A_k0.s3=AlocalC[7][row];\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" sum[wm][wn] += dot(A_k0,B_k0[wn]);\n"
" }\n"
" row += TPWM;\n"
" }\n"
" }\n"
" }\n"
" }\n"
" // Store the final results in C\n"
" for (int wm=0; wm<OPTM; wm++) {\n"
" int globalRow=offsetM+lidm+wm*TPWM;\n"
" for (int wn=0; wn<OPTN; wn++) {\n"
" int globalCol=offsetN+lidn+wn*TPWN;\n"
" #ifdef BIAS\n"
" sum[wm][wn] += bias[globalCol];\n"
" #endif\n"
" C[globalRow*N+globalCol]=sum[wm][wn];\n"
" }\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* conv_2d_int_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define MOD_NUM 15\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
"__kernel\n"
"void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_hw.y;\n"
" const int out_w_idx=out_c_w_idx % out_hw.y;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale=(COMPUTE_FLOAT4)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT4 offset=(COMPUTE_FLOAT4)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+kw_start)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" for(int ix=in_w_idx_start; ix<in_w_idx_end; ix += dilate_hw.y) {\n"
" int inp_offset=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+ix)*4;\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" \n"
" const int filter_w_inc=(ix-in_w_idx_start)/dilate_hw.y;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(filter_w_inc,weight+weight_offset);\n"
" char4 charWeight1=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#else\n"
" uchar2 charWeightInt40=vload2(filter_w_inc,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(filter_w_inc,weight+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(filter_w_inc,weight+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(filter_w_inc,weight+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" \n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" }\n"
" weight_offset += 4*filter_hw.y;\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" \n"
"}\n"
"__kernel\n"
"void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,//generate width's num\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=(out_c_w_idx % out_w_blocks) << 1;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" COMPUTE_FLOAT4 out1=bias0;\n"
" \n"
" const int in_w0_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_w1_idx_base=in_w0_idx_base+stride_hw.y;\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale=(COMPUTE_FLOAT4)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT4 offset=(COMPUTE_FLOAT4)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" const int inp_offset_base=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+0)*4;\n"
" for(int fw=0; fw<filter_hw.y; fw++) {\n"
" const int in_w0_idx=fw*dilate_hw.y+in_w0_idx_base;\n"
" const int in_w1_idx=fw*dilate_hw.y+in_w1_idx_base;\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,weight+weight_offset+weight_oc_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#else\n"
" uchar2 charWeightInt40=vload2(0,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,weight+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,weight+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,weight+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" \n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" if(out_w_idx+1 >= out_hw.y) return;\n"
" vstore4(CONVERT_FLOAT4(out1),1,output+out_offset);\n"
"#else\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=(out_c_w_idx % out_w_blocks) << 2;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" COMPUTE_FLOAT4 out1=bias0;\n"
" COMPUTE_FLOAT4 out2=bias0;\n"
" COMPUTE_FLOAT4 out3=bias0;\n"
" const int in_w0_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_w1_idx_base=in_w0_idx_base+stride_hw.y;\n"
" const int in_w2_idx_base=in_w1_idx_base+stride_hw.y;\n"
" const int in_w3_idx_base=in_w2_idx_base+stride_hw.y;\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale=(COMPUTE_FLOAT4)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT4 offset=(COMPUTE_FLOAT4)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" const int inp_offset_base=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+0)*4;\n"
" for(int fw=0; fw<filter_hw.y; fw++) {\n"
" const int in_w0_idx=fw*dilate_hw.y+in_w0_idx_base;\n"
" const int in_w1_idx=fw*dilate_hw.y+in_w1_idx_base;\n"
" const int in_w2_idx=fw*dilate_hw.y+in_w2_idx_base;\n"
" const int in_w3_idx=fw*dilate_hw.y+in_w3_idx_base;\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx,input+inp_offset_base));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,weight+weight_offset+weight_oc_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#else\n"
" uchar2 charWeightInt40=vload2(0,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,weight+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,weight+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,weight+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.y-out_w_idx;\n"
" if (remain >= 4) {\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
"#endif\n"
"}\n"
// ---------------------------------------------------------------------------
// conv_2d_int_c4h4w1 (OpenCL C kernel, embedded below as source text).
// 2D convolution over buffer tensors with low-bit quantized weights that are
// dequantized on the fly as q * scale + offset (per output channel).
// One work-item computes one 4-channel output block (out_c_idx) at one output
// column (out_w_idx) for four consecutive output rows starting at out_h_idx:
//   - gid0 decodes to (out_c_idx, out_w_idx) = (gid0 / out_w_blocks, gid0 % out_w_blocks)
//   - gid1 decodes to (out_b_idx, row block) = (gid1 / out_h_blocks, gid1 % out_h_blocks),
//     with out_h_idx = 4 * row block.
// Layouts implied by the offset arithmetic in the kernel:
//   - input / output: [C/4][N][H][W][4]; plane index = b + c_block * batch.
//   - weight: [4*in_c_blocks][out_c_blocks][kh][kw][4]. The innermost 4 values
//     are the 4 output channels of a block; they are multiplied by one scalar
//     input channel in mad(inN.x, weight0, outN). Without
//     USE_LOW_BIT_WEIGHT_INT8, two 4-bit weights are packed per byte.
//   - dequantScaleOffset: for each group of blockDim input channels,
//     out_c_blocks runs of 8 floats = 4 interleaved (scale, offset) pairs.
// ---------------------------------------------------------------------------
"__kernel\n"
"void conv_2d_int_c4h4w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n"
" const int out_h_idx=(out_b_h_idx % out_h_blocks) << 2;\n"
" \n"
// All four output rows start from the bias of this output-channel block.
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" COMPUTE_FLOAT4 out1=bias0;\n"
" COMPUTE_FLOAT4 out2=bias0;\n"
" COMPUTE_FLOAT4 out3=bias0;\n"
// Top-left input coordinates: one column and four rows spaced by stride_hw.x.
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n"
" const int in_h2_idx_base=in_h1_idx_base+stride_hw.x;\n"
" const int in_h3_idx_base=in_h2_idx_base+stride_hw.x;\n"
" \n"
// kw_start is the first kernel column whose input column is >= 0 (a ceiling
// division). [in_w_idx_start, in_w_idx_end) is the clipped input-column range.
// The "fw" loop below therefore walks input columns directly and needs no
// column bounds check.
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
// NOTE: despite its name, weight_oc_offset is the stride between consecutive
// INPUT channels in the weight buffer (out_c_blocks * kh * kw * 4 elements).
// It is used below to reach input channels 4*in_c_idx + 1..3.
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" const int in_hw_size=in_hw.x*in_hw.y;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
// Select the (scale, offset) group that covers input channel 4*in_c_idx.
// NOTE(review): all 4 channels of the block share this group, which presumably
// assumes blockDim is a multiple of 4 -- confirm against the host code.
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale=(COMPUTE_FLOAT4)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT4 offset=(COMPUTE_FLOAT4)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n"
// Kernel rows are not clipped. An input row outside [0, in_hw.x) gives
// in_hN_idx = row * in_hw.y outside [0, in_hw_size), and the loads below then
// substitute zeros.
" for(int iy=0; iy<filter_hw.x; iy++) {\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+iy)*filter_hw.y+kw_start)*4;\n"
" const int in_h0_idx=(iy*dilate_hw.x+in_h0_idx_base)*in_hw.y;\n"
" const int in_h1_idx=(iy*dilate_hw.x+in_h1_idx_base)*in_hw.y;\n"
" const int in_h2_idx=(iy*dilate_hw.x+in_h2_idx_base)*in_hw.y;\n"
" const int in_h3_idx=(iy*dilate_hw.x+in_h3_idx_base)*in_hw.y;\n"
" for(int fw=in_w_idx_start; fw<in_w_idx_end; fw += dilate_hw.y) {\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_h0_idx<0 || in_h0_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h0_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_h1_idx<0 || in_h1_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h1_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_h2_idx<0 || in_h2_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h2_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_h3_idx<0 || in_h3_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h3_idx+fw,input+inp_offset_base));\n"
// int8 path: load the 4 output-channel weights for each of the 4 input
// channels of this block, then dequantize them.
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,weight+weight_offset+weight_oc_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#else\n"
// int4 path: element offsets are halved to byte offsets. Each byte holds two
// weights, high nibble first, and each nibble is re-centred by -8 to [-8, 7].
" uchar2 charWeightInt40=vload2(0,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,weight+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,weight+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,weight+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale+offset;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale+offset;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale+offset;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale+offset;\n"
"#endif\n"
// When INPUT_CHANNEL_LEAVE is defined, PADZEROSVEC (defined at the top of this
// source string) zeroes the weights of padded input channels >= inChannel.
// This is needed because q*scale+offset is not guaranteed to be zero there.
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
// Accumulate: each input channel (inN.x..w) scales its 4-output-channel
// weight vector.
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
// Advance to the next kernel column (4 output channels per tap).
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
" \n"
// Optional fused activation.
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
// Output rows are out_hw.y*4 elements apart. vstore4's index is in units of 4
// elements, so index k*out_hw.y addresses row out_h_idx + k.
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
// BLOCK_LEAVE: the last row block may hold fewer than 4 valid rows.
// NOTE(review): without it, all 4 rows are stored unconditionally. Presumably
// it is only omitted when out_hw.x % 4 == 0 -- confirm with the host code.
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.x-out_h_idx;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
"#endif\n"
"}\n"
// ---------------------------------------------------------------------------
// conv_2d_int_c8h4w1 (OpenCL C kernel, embedded below as source text).
// 2D convolution with low-bit quantized weights dequantized on the fly as
// q * scale + offset. Each work-item computes TWO adjacent 4-channel output
// blocks, out_c_idx and out_c_idx+1, at one output column for four
// consecutive output rows:
//   - out0..out3 hold rows out_h_idx..out_h_idx+3 of block out_c_idx
//   - out4..out7 hold the same rows of block out_c_idx+1
// gid0 decodes to (2 * (gid0 / out_w_blocks), gid0 % out_w_blocks), so
// out_c_idx is always even. The in-kernel "//c/4 w" note predates this pairing.
// gid1 decodes to (out_b_idx, row block) with out_h_idx = 4 * row block.
// Layouts implied by the offset arithmetic in the kernel:
//   - input / output: [C/4][N][H][W][4]; plane index = b + c_block * batch.
//   - weight: [4*in_c_blocks][out_c_blocks][kh][kw][4], innermost 4 = output
//     channels of a block. Without USE_LOW_BIT_WEIGHT_INT8, two 4-bit weights
//     are packed per byte.
//   - dequantScaleOffset: for each group of blockDim input channels,
//     out_c_blocks runs of 8 floats = 4 interleaved (scale, offset) pairs.
// ---------------------------------------------------------------------------
"__kernel\n"
"void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n"
" const int out_h_idx=(out_b_h_idx % out_h_blocks) << 2;\n"
" \n"
// NOTE(review): the bias (here), the scale/offset and the weights of block
// out_c_idx+1 are read unconditionally, while CHANNEL_LEAVE below only skips
// the stores. This presumably relies on those buffers being padded to an even
// number of output-channel blocks -- confirm with the host-side allocation.
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" COMPUTE_FLOAT4 out1=bias0;\n"
" COMPUTE_FLOAT4 out2=bias0;\n"
" COMPUTE_FLOAT4 out3=bias0;\n"
" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n"
" COMPUTE_FLOAT4 out4=bias1;\n"
" COMPUTE_FLOAT4 out5=bias1;\n"
" COMPUTE_FLOAT4 out6=bias1;\n"
" COMPUTE_FLOAT4 out7=bias1;\n"
// Top-left input coordinates: one column and four rows spaced by stride_hw.x.
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n"
" const int in_h2_idx_base=in_h1_idx_base+stride_hw.x;\n"
" const int in_h3_idx_base=in_h2_idx_base+stride_hw.x;\n"
" \n"
// kw_start is the first kernel column whose input column is >= 0 (a ceiling
// division). The "fw" loop walks the clipped input-column range
// [in_w_idx_start, in_w_idx_end) directly.
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
// weight_oc_offset is the stride between adjacent OUTPUT-channel blocks
// (kh*kw*4 elements). weight_ic_offset is the stride between adjacent INPUT
// channels (out_c_blocks times that).
" const int weight_oc_offset=filter_hw.x*filter_hw.y*4;\n"
" const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n"
" const int in_hw_size=in_hw.x*in_hw.y;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
// (scale, offset) groups for both output blocks, taken from the group that
// covers input channel 4*in_c_idx.
// NOTE(review): this presumably assumes blockDim is a multiple of 4 -- confirm.
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n"
// Kernel rows are not clipped. Out-of-range input rows fall outside
// [0, in_hw_size) once multiplied by in_hw.y, and the loads below then
// substitute zeros.
" for(int iy=0; iy<filter_hw.x; iy++) {\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+iy)*filter_hw.y+kw_start)*4;\n"
" const int in_h0_idx=(iy*dilate_hw.x+in_h0_idx_base)*in_hw.y;\n"
" const int in_h1_idx=(iy*dilate_hw.x+in_h1_idx_base)*in_hw.y;\n"
" const int in_h2_idx=(iy*dilate_hw.x+in_h2_idx_base)*in_hw.y;\n"
" const int in_h3_idx=(iy*dilate_hw.x+in_h3_idx_base)*in_hw.y;\n"
" for(int fw=in_w_idx_start; fw<in_w_idx_end; fw += dilate_hw.y) {\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_h0_idx<0 || in_h0_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h0_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_h1_idx<0 || in_h1_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h1_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_h2_idx<0 || in_h2_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h2_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_h3_idx<0 || in_h3_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h3_idx+fw,input+inp_offset_base));\n"
// First output block (out_c_idx): weights for the 4 input channels of this
// block, dequantized with scale0/offset0.
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_ic_offset*2);\n"
" char4 charWeight3=vload4(0,weight+weight_offset+weight_ic_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale0+offset0;\n"
"#else\n"
// int4 path: byte offsets are half the element offsets. Each byte holds two
// weights, high nibble first, and each nibble is re-centred by -8 to [-8, 7].
" uchar2 charWeightInt40=vload2(0,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,weight+weight_offset/2+weight_ic_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,weight+weight_offset/2+weight_ic_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,weight+weight_offset/2+weight_ic_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale0+offset0;\n"
"#endif\n"
// When INPUT_CHANNEL_LEAVE is defined, PADZEROSVEC (defined at the top of this
// source string) zeroes the weights of padded input channels >= inChannel.
// This is needed because q*scale+offset is not guaranteed to be zero there.
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
// Second output block (out_c_idx+1): reuses in0..in3, shifts the weights by
// weight_oc_offset and dequantizes with scale1/offset1.
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" charWeight0=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" charWeight1=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n"
" weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n"
" weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n"
" weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale1+offset1;\n"
"#else\n"
" charWeightInt40=vload2(0,weight+weight_offset/2+weight_oc_offset/2);\n"
" charWeightInt41=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n"
" charWeightInt42=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n"
" charWeightInt43=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n"
" charWeight0=(char4)(0,0,0,0);\n"
" charWeight1=(char4)(0,0,0,0);\n"
" charWeight2=(char4)(0,0,0,0);\n"
" charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n"
" weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n"
" weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n"
" weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale1+offset1;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" out4=mad(in0.x,weight0,out4);\n"
" out4=mad(in0.y,weight1,out4);\n"
" out4=mad(in0.z,weight2,out4);\n"
" out4=mad(in0.w,weight3,out4);\n"
" \n"
" out5=mad(in1.x,weight0,out5);\n"
" out5=mad(in1.y,weight1,out5);\n"
" out5=mad(in1.z,weight2,out5);\n"
" out5=mad(in1.w,weight3,out5);\n"
" \n"
" out6=mad(in2.x,weight0,out6);\n"
" out6=mad(in2.y,weight1,out6);\n"
" out6=mad(in2.z,weight2,out6);\n"
" out6=mad(in2.w,weight3,out6);\n"
" \n"
" out7=mad(in3.x,weight0,out7);\n"
" out7=mad(in3.y,weight1,out7);\n"
" out7=mad(in3.z,weight2,out7);\n"
" out7=mad(in3.w,weight3,out7);\n"
" \n"
// Advance to the next kernel column (4 output channels per tap).
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
// Optional fused activation.
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
" out4=fmax(out4,(COMPUTE_FLOAT4)0);\n"
" out5=fmax(out5,(COMPUTE_FLOAT4)0);\n"
" out6=fmax(out6,(COMPUTE_FLOAT4)0);\n"
" out7=fmax(out7,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out4=clamp(out4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out5=clamp(out5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
// Rows are stored at vstore4 indices 0, out_hw.y, 2*out_hw.y and 3*out_hw.y
// (units of 4 elements). out_offset is recomputed for block out_c_idx+1.
" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
// BLOCK_LEAVE: the last row block may hold fewer than 4 valid rows.
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.x-out_h_idx;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
// CHANNEL_LEAVE: when out_c_blocks is odd, the second block of the last pair
// lies past the end and is not stored.
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
"#endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out7),3*out_hw.y,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" }\n"
// NOTE(review): without BLOCK_LEAVE, all 4 rows are stored unconditionally.
// Presumably it is only omitted when out_hw.x % 4 == 0 -- confirm with the
// host code.
"#else\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
"#endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out7),3*out_hw.y,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n"
" const int out_h_idx=(out_b_h_idx % out_h_blocks) << 1;\n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" COMPUTE_FLOAT4 out1=bias0;\n"
" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n"
" COMPUTE_FLOAT4 out2=bias1;\n"
" COMPUTE_FLOAT4 out3=bias1;\n"
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n"
" \n"
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
" const int weight_oc_offset=filter_hw.x*filter_hw.y*4;\n"
" const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n"
" const int in_hw_size=in_hw.x*in_hw.y;\n"
" // weight: [ic/4,oc,4],loop: ic/4\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n"
" for(int iy=0; iy<filter_hw.x; iy++) {\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+iy)*filter_hw.y+kw_start)*4;\n"
" const int in_h0_idx=(iy*dilate_hw.x+in_h0_idx_base)*in_hw.y;\n"
" const int in_h1_idx=(iy*dilate_hw.x+in_h1_idx_base)*in_hw.y;\n"
" for(int fw=in_w_idx_start; fw<in_w_idx_end; fw += dilate_hw.y) {\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_h0_idx<0 || in_h0_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h0_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_h1_idx<0 || in_h1_idx >= in_hw_size) ? (FLOAT4)0 : vload4(in_h1_idx+fw,input+inp_offset_base));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_ic_offset*2);\n"
" char4 charWeight3=vload4(0,weight+weight_offset+weight_ic_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale0+offset0;\n"
"#else\n"
" uchar2 charWeightInt40=vload2(0,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,weight+weight_offset/2+weight_ic_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,weight+weight_offset/2+weight_ic_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,weight+weight_offset/2+weight_ic_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale0+offset0;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" charWeight0=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" charWeight1=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n"
" weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n"
" weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n"
" weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale1+offset1;\n"
"#else\n"
" charWeightInt40=vload2(0,weight+weight_offset/2+weight_oc_offset/2);\n"
" charWeightInt41=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n"
" charWeightInt42=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n"
" charWeightInt43=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n"
" charWeight0=(char4)(0,0,0,0);\n"
" charWeight1=(char4)(0,0,0,0);\n"
" charWeight2=(char4)(0,0,0,0);\n"
" charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0& MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1& MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0& MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1& MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0& MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1& MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0& MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1& MOD_NUM)-8;\n"
" weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n"
" weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n"
" weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n"
" weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale1+offset1;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" out2=mad(in0.x,weight0,out2);\n"
" out2=mad(in0.y,weight1,out2);\n"
" out2=mad(in0.z,weight2,out2);\n"
" out2=mad(in0.w,weight3,out2);\n"
" \n"
" out3=mad(in1.x,weight0,out3);\n"
" out3=mad(in1.y,weight1,out3);\n"
" out3=mad(in1.z,weight2,out3);\n"
" out3=mad(in1.w,weight3,out3);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.x-out_h_idx;\n"
" if(remain >= 2){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
"#endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" if(remain >= 2){\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
"#endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks,\n"
" __private const int blockDim) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=(out_c_w_idx % out_w_blocks) << 2;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out0=bias0;\n"
" COMPUTE_FLOAT4 out1=bias0;\n"
" COMPUTE_FLOAT4 out2=bias0;\n"
" COMPUTE_FLOAT4 out3=bias0;\n"
" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n"
" COMPUTE_FLOAT4 out4=bias1;\n"
" COMPUTE_FLOAT4 out5=bias1;\n"
" COMPUTE_FLOAT4 out6=bias1;\n"
" COMPUTE_FLOAT4 out7=bias1;\n"
" const int in_w0_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_w1_idx_base=in_w0_idx_base+stride_hw.y;\n"
" const int in_w2_idx_base=in_w1_idx_base+stride_hw.y;\n"
" const int in_w3_idx_base=in_w2_idx_base+stride_hw.y;\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
" const int weight_oc_offset=filter_hw.x*filter_hw.y*4;\n"
" const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" int kindex=(in_c_idx*4)/blockDim*out_c_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" const int inp_offset_base=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+0)*4;\n"
" for(int fw=0; fw<filter_hw.y; fw++) {\n"
" const int in_w0_idx=fw*dilate_hw.y+in_w0_idx_base;\n"
" const int in_w1_idx=fw*dilate_hw.y+in_w1_idx_base;\n"
" const int in_w2_idx=fw*dilate_hw.y+in_w2_idx_base;\n"
" const int in_w3_idx=fw*dilate_hw.y+in_w3_idx_base;\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx,input+inp_offset_base));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_ic_offset*2);\n"
" char4 charWeight3=vload4(0,weight+weight_offset+weight_ic_offset*3);\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale0+offset0;\n"
"#else\n"
" uchar2 charWeightInt40=vload2(0,weight+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,weight+weight_offset/2+weight_ic_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,weight+weight_offset/2+weight_ic_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,weight+weight_offset/2+weight_ic_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale0+offset0;\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale0+offset0;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" charWeight0=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" charWeight1=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n"
" weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n"
" weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n"
" weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale1+offset1;\n"
"#else\n"
" charWeightInt40=vload2(0,weight+weight_offset/2+weight_oc_offset/2);\n"
" charWeightInt41=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n"
" charWeightInt42=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n"
" charWeightInt43=vload2(0,weight+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n"
" charWeight0=(char4)(0,0,0,0);\n"
" charWeight1=(char4)(0,0,0,0);\n"
" charWeight2=(char4)(0,0,0,0);\n"
" charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)- 8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n"
" weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n"
" weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n"
" weight3=CONVERT_COMPUTE_FLOAT4(charWeight3)*scale1+offset1;\n"
"#endif\n"
" PADZEROSVEC(in_c_idx,inChannel,weight0,weight1,weight2,weight3);\n"
" \n"
" out4=mad(in0.x,weight0,out4);\n"
" out4=mad(in0.y,weight1,out4);\n"
" out4=mad(in0.z,weight2,out4);\n"
" out4=mad(in0.w,weight3,out4);\n"
" \n"
" out5=mad(in1.x,weight0,out5);\n"
" out5=mad(in1.y,weight1,out5);\n"
" out5=mad(in1.z,weight2,out5);\n"
" out5=mad(in1.w,weight3,out5);\n"
" \n"
" out6=mad(in2.x,weight0,out6);\n"
" out6=mad(in2.y,weight1,out6);\n"
" out6=mad(in2.z,weight2,out6);\n"
" out6=mad(in2.w,weight3,out6);\n"
" \n"
" out7=mad(in3.x,weight0,out7);\n"
" out7=mad(in3.y,weight1,out7);\n"
" out7=mad(in3.z,weight2,out7);\n"
" out7=mad(in3.w,weight3,out7);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
" out4=fmax(out4,(COMPUTE_FLOAT4)0);\n"
" out5=fmax(out5,(COMPUTE_FLOAT4)0);\n"
" out6=fmax(out6,(COMPUTE_FLOAT4)0);\n"
" out7=fmax(out7,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out4=clamp(out4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out5=clamp(out5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.y-out_w_idx;\n"
" if(remain >= 4){\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks)return;\n"
"#endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" if(remain >= 4){\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks)return;\n"
"#endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n"
"#endif\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL buffer-path interpolation kernels (this string is compiled at runtime):
//  - nearest_buf:   2D nearest-neighbour resize; USE_ROUND picks floor(x+0.499f) instead of floor(x).
//  - bilinear_buf:  2D bilinear resize from the four clamped neighbours, blended in float.
//  - nearest3D_buf: 3D (depth/height/width) nearest-neighbour resize.
// Each work-item moves one FLOAT4 (one 4-channel block); offsets address the buffer as
// [channel_block][batch][...spatial...][4] via vload4/vstore4.
// nearest3D_buf decodes global dim 1 as h*out_width+w. It previously divided/modded by
// out_height, which mis-decodes (and writes past the last output row into the next depth
// slice) whenever out_height != out_width.
// NOTE(review): this table looks generated from interp_buf.cl; apply the same fix there.
const char* interp_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n"
" __global FLOAT* output,\n"
" __private const float height_scale,\n"
" __private const float width_scale,\n"
" __private const float height_offset,\n"
" __private const float width_offset,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int out_height,\n"
" __private const int out_width,\n"
" __private const int batch) {\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_batch_idx=output_batch_height_block_idx/out_height;\n"
" const int output_height_idx=output_batch_height_block_idx % out_height;\n"
" const float in_h_idx=output_height_idx*height_scale+height_offset;\n"
" const float in_w_idx=output_width_block_idx*width_scale+width_offset;\n"
"#ifdef USE_ROUND\n"
" const int in_h_index=min(max(0,(int)floor(in_h_idx+0.499f)),input_height-1);\n"
" const int in_w_index=min(max(0,(int)floor(in_w_idx+0.499f)),input_width-1);\n"
"#else\n"
" const int in_h_index=min(max(0,(int)floor(in_h_idx)),input_height-1);\n"
" const int in_w_index=min(max(0,(int)floor(in_w_idx)),input_width-1);\n"
"#endif\n"
" const int inp_offset=((output_batch_idx+output_channel_block_idx*batch)*input_height+in_h_index)*input_width+in_w_index;\n"
" FLOAT4 value=vload4(inp_offset,input);\n"
" const int out_offset=((output_batch_idx+output_channel_block_idx*batch)*out_height+output_height_idx)*out_width+output_width_block_idx;\n"
" vstore4(value,out_offset,output);\n"
"}\n"
"__kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n"
" __global FLOAT* output,\n"
" __private const float height_scale,\n"
" __private const float width_scale,\n"
" __private const float height_offset,\n"
" __private const float width_offset,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int out_height,\n"
" __private const int out_width,\n"
" __private const int batch) {\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" \n"
" const int output_batch_idx=output_batch_height_block_idx/out_height;\n"
" const int output_height_idx=output_batch_height_block_idx % out_height;\n"
" const float in_h_idx=output_height_idx*height_scale+height_offset;\n"
" const float in_w_idx=output_width_block_idx*width_scale+width_offset;\n"
" const int in_h0_index=min(max(0,(int)floor(in_h_idx)),input_height-1);\n"
" const int in_w0_index=min(max(0,(int)floor(in_w_idx)),input_width-1);\n"
" const int in_h1_index=min(max(0,(int)floor(in_h_idx)+1),input_height-1);\n"
" const int in_w1_index=min(max(0,(int)floor(in_w_idx)+1),input_width-1);\n"
" \n"
" float factor_w=(in_w_idx-(int)floor(in_w_idx));\n"
" float factor_h=(in_h_idx-(int)floor(in_h_idx));\n"
" \n"
" const int inp_offset_base=(output_batch_idx+output_channel_block_idx*batch)*input_height;\n"
" const int inp_offset_00=(inp_offset_base+in_h0_index)*input_width+in_w0_index;\n"
" const int inp_offset_01=(inp_offset_base+in_h0_index)*input_width+in_w1_index;\n"
" const int inp_offset_10=(inp_offset_base+in_h1_index)*input_width+in_w0_index;\n"
" const int inp_offset_11=(inp_offset_base+in_h1_index)*input_width+in_w1_index;\n"
" FLOAT4 value_00=vload4(inp_offset_00,input);\n"
" FLOAT4 value_01=vload4(inp_offset_01,input);\n"
" FLOAT4 value_10=vload4(inp_offset_10,input);\n"
" FLOAT4 value_11=vload4(inp_offset_11,input);\n"
" FLOAT4 value=CONVERT_FLOAT4((float4)((1.0f-factor_w)*(1.0f-factor_h))*convert_float4(value_00)+(float4)(factor_w*(1.0f-factor_h))*convert_float4(value_01)+(float4)((1.0f-factor_w)*factor_h)*convert_float4(value_10)+(float4)(factor_w*factor_h)*convert_float4(value_11));\n"
" \n"
" const int out_offset=((output_batch_idx+output_channel_block_idx*batch)*out_height+output_height_idx)*out_width+output_width_block_idx;\n"
" \n"
" vstore4(value,out_offset,output);\n"
"}\n"
"__kernel void nearest3D_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n"
" __global FLOAT* output,\n"
" __private const float depth_scale,\n"
" __private const float height_scale,\n"
" __private const float width_scale,\n"
" __private const float depth_offset,\n"
" __private const float height_offset,\n"
" __private const float width_offset,\n"
" __private const int input_depth,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int out_depth,\n"
" __private const int out_height,\n"
" __private const int out_width,\n"
" __private const int batch) {\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_height_width_block_idx=get_global_id(1);\n"
" const int output_batch_depth_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_height_width_block_idx,output_batch_depth_block_idx);\n"
" const int output_batch_idx=output_batch_depth_block_idx/out_depth;\n"
" const int output_depth_idx=output_batch_depth_block_idx % out_depth;\n"
" const int output_height_idx=output_height_width_block_idx/out_width;\n"
" const int output_width_idx=output_height_width_block_idx % out_width;\n"
" const float in_d_idx=output_depth_idx*depth_scale+depth_offset;\n"
" const float in_h_idx=output_height_idx*height_scale+height_offset;\n"
" const float in_w_idx=output_width_idx*width_scale+width_offset;\n"
" const int in_d_index=min(max(0,(int)floor(in_d_idx)),input_depth-1);\n"
" const int in_h_index=min(max(0,(int)floor(in_h_idx)),input_height-1);\n"
" const int in_w_index=min(max(0,(int)floor(in_w_idx)),input_width-1);\n"
" const int inp_offset=(((output_batch_idx+output_channel_block_idx*batch)\n"
"*input_depth+in_d_index)*input_height+in_h_index)*input_width+in_w_index;\n"
" const int out_offset=(((output_batch_idx+output_channel_block_idx*batch)\n"
"*out_depth+output_depth_idx)*out_height+output_height_idx)*out_width+output_width_idx;\n"
" FLOAT4 value=vload4(inp_offset,input);\n"
" vstore4(value,out_offset,output);\n"
"}\n"
;
#endif
// OpenCL image-path kernel source for "scale" (compiled at runtime):
//   out = in * scale_value            (default)
//   out = in * scale_value + bias_value  (when HAS_BIAS is defined)
// The input pixel is read at x = channel_block_idx*width + w (width = global_size_dim1),
// y = hb; the per-channel scale/bias images are read at (channel_block_idx, 0).
// FLOAT4, RI_F and WI_F are not defined in this source; presumably they are injected
// through the program build options (verify against the host-side builder).
// Written as a raw string literal; the bytes are identical to the former line-by-line
// literal, including the double spaces inside the two #define lines.
const char* scale = R"CL(#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
#define GLOBAL_SIZE_3_DIMS  __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,
#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3)  if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) {  return;  }
__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void scale(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__read_only image2d_t scale,
#ifdef HAS_BIAS
 __read_only image2d_t bias,/* cout%4*cout/4 */
#endif
 __write_only image2d_t output) {
 const int channel_block_idx=get_global_id(0);
 const int w=get_global_id(1);
 const int hb=get_global_id(2);
 DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);
 const int width=global_size_dim1;
 const int pos=mad24(channel_block_idx,width,w);
 FLOAT4 in=RI_F(input,SAMPLER,(int2)(pos,hb));
 FLOAT4 scale_value=RI_F(scale,SAMPLER,(int2)(channel_block_idx,0));
#ifdef HAS_BIAS
 FLOAT4 bias_value=RI_F(bias,SAMPLER,(int2)(channel_block_idx,0));
 FLOAT4 out=in*scale_value+bias_value;
#else
 FLOAT4 out=in*scale_value;
#endif
 WI_F(output,(int2)(pos,hb),out);
}
)CL";
// OpenCL image-path softmax kernels (this string is compiled at runtime):
//  - softmax_channel: normalises across channel blocks; pixels are read at
//    x = channel_block*shape.w + w. remain_channels is the number of padded lanes in the
//    last channel block (0 = all four lanes valid, 3 = only .x valid); padded lanes are
//    excluded from both the max and the exp-sum.
//  - softmax_height / softmax_width: normalise each FLOAT4 lane independently along H / W.
// With SOFTMAX_LOCAL_SIZE >= 4 each work-group tree-reduces its partial max and exp-sum
// through the __local array `sum`. The halving loop assumes SOFTMAX_LOCAL_SIZE is a power
// of two and that get_local_id(0) spans [0, SOFTMAX_LOCAL_SIZE) (TODO confirm against the
// host launch).
// A barrier follows every `maxValue=sum[0];` read. Without it, a fast work-item could
// overwrite sum[0] with its partial exp-sum before slower work-items of the same group had
// read the reduced max (a local-memory data race).
// NOTE(review): this table looks generated from softmax.cl; apply the same fix there.
const char* softmax = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define EXP exp\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void softmax_channel(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const int remain_channels,__private const int4 shape // NCHW\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int w=get_global_id(1);\n"
" const int bh=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(x,w,bh);\n"
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n"
" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n"
" for (int i=lid; i<shape.y-1; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh)));\n"
" }\n"
" sum[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=sum[0];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" maxValue.x=fmax(maxValue.x,maxValue.y);\n"
" maxValue.x=fmax(maxValue.x,maxValue.z);\n"
" maxValue.x=fmax(maxValue.x,maxValue.w);\n"
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(w+(shape.y-1)*shape.w ,bh));\n"
" if (remain_channels == 0) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.w);\n"
" } else if (remain_channels == 1) {\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" }\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=lid; i<shape.y-1; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-(FLOAT4)maxValue.x);\n"
" }\n"
" sum[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum[0];\n"
" sumValue.x=sumValue.x+sumValue.y+sumValue.z+sumValue.w;\n"
" \n"
" \n"
" input_data -= maxValue.x;\n"
" if (remain_channels == 0) {\n"
" sumValue.x += exp(input_data.w);\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 1) {\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" sumValue.x += exp(input_data.x);\n"
" }\n"
" for(int i=lid; i<shape.y; i+=SOFTMAX_LOCAL_SIZE){\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-maxValue.x)/sumValue.x;\n"
" WI_F(output,(int2)(w+i*shape.w,bh),value);\n"
" }\n"
"#else\n"
" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n"
" for (int i=0; i<shape.y-1; i++) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh)));\n"
" }\n"
" \n"
" maxValue.x=fmax(maxValue.x,maxValue.y);\n"
" maxValue.x=fmax(maxValue.x,maxValue.z);\n"
" maxValue.x=fmax(maxValue.x,maxValue.w);\n"
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(w+(shape.y-1)*shape.w ,bh));\n"
" if (remain_channels == 0) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.w);\n"
" } else if (remain_channels == 1) {\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" }\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=0; i<shape.y-1; i++) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-(FLOAT4)maxValue.x);\n"
" }\n"
" sumValue.x=sumValue.x+sumValue.y+sumValue.z+sumValue.w;\n"
" input_data -= maxValue.x;\n"
" if (remain_channels == 0) {\n"
" sumValue.x += exp(input_data.w);\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 1) {\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" sumValue.x += exp(input_data.x);\n"
" }\n"
" for(int i=0; i<shape.y; i++){\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-maxValue.x)/sumValue.x;\n"
" WI_F(output,(int2)(w+i*shape.w,bh),value);\n"
" }\n"
"#endif\n"
"}\n"
"__kernel void softmax_height(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const int remain_channels,__private const int4 shape // NCHW\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int wc=get_global_id(1);\n"
" const int b=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(x,wc,b);\n"
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i)));\n"
" }\n"
" sum[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=sum[0];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" \n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue);\n"
" }\n"
" sum[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum[0];\n"
" \n"
" /*Compute Result */\n"
" for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(wc,b*shape.z+i),value);\n"
" }\n"
"#else\n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=0; i<shape.z; i++) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i)));\n"
" }\n"
" \n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=0; i<shape.z; i++) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue);\n"
" }\n"
" \n"
" /*Compute Result */\n"
" for (int i=0; i<shape.z; i++) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(wc,b*shape.z+i),value);\n"
" }\n"
"#endif\n"
"}\n"
"__kernel void softmax_width(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const int remain_channels,__private const int4 shape // NCHW\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int c=get_global_id(1);\n"
" const int bh=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(x,c,bh);\n"
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n"
" \n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh)));\n"
" }\n"
" sum[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=sum[0];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" \n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue);\n"
" }\n"
" sum[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum[0];\n"
" \n"
" /*Compute Result */\n"
" for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(c*shape.w+i,bh),value);\n"
" }\n"
"#else\n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=0; i<shape.w; i++) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh)));\n"
" }\n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=0; i<shape.w; i++) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue);\n"
" }\n"
" \n"
" /*Compute Result */\n"
" for (int i=0; i<shape.w; i++) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(c*shape.w+i,bh),value);\n"
" }\n"
"#endif\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source (stored as a C string) for the buffer-mode element-wise
// binary kernel and the PReLU kernel.
// The kernel text uses macros it does not define itself: INPUT_TYPE,
// OUTPUT_TYPE, OPERATOR, CONVERT_OUTPUT4, and the optional switches
// PACK_LEAVE, A_SINGLE and B_SINGLE. Presumably the host passes them as
// program build options. TODO: confirm against the code that compiles this source.
const char* binary_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define PI 3.141592653589f\n"
// binary_buf: each work-item handles 4 consecutive elements starting at
// offset = get_global_id(0) * 4. It evaluates OPERATOR on float4 copies of the
// inputs and writes the result converted to OUTPUT_TYPE.
// Work-items outside (global_dim0, global_dim1) do nothing.
"__kernel void binary_buf(__private int global_dim0,__private int global_dim1,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int size,\n"
" __private const int activationType) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));//NCHW,1\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" int offset=pos.x << 2;\n"
// PACK_LEAVE: when fewer than 4 of the `size` elements remain at `offset`,
// load and store element by element up to `remain` instead of vload4/vstore4.
// Lanes >= remain of in0/in1 stay uninitialized. OPERATOR still evaluates
// them, but those results are never stored.
"#ifdef PACK_LEAVE\n"
" if(offset+3 >= size){\n"
" int remain=size-offset;\n"
" float4 in0,in1;\n"
" float* in0_ptr=(float*)&in0;\n"
" float* in1_ptr=(float*)&in1;\n"
" \n"
" for(int i=0; i<remain; ++i){\n"
" #ifdef A_SINGLE\n"
" in0_ptr[i]=(float)input0[0];\n"
" #else\n"
" in0_ptr[i]=(float)input0[offset+i];\n"
" #endif\n"
" \n"
" #ifdef B_SINGLE\n"
" in1_ptr[i]=(float)input1[0];\n"
" #else\n"
" in1_ptr[i]=(float)input1[offset+i];\n"
" #endif\n"
" }\n"
" float4 out=OPERATOR;\n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" float* out_ptr=(float*)&out;\n"
" for(int i=0; i<remain; ++i){\n"
" output[offset+i]=(OUTPUT_TYPE)out_ptr[i];\n"
" }\n"
" }else {\n"
"#endif\n"
// Full 4-element path. A_SINGLE / B_SINGLE broadcast element 0 of the
// corresponding input to all four lanes.
" #ifdef A_SINGLE\n"
" float data0=input0[0];\n"
" float4 in0=(float4)(data0,data0,data0,data0);\n"
" #else\n"
" float4 in0=convert_float4(vload4(0,input0+offset));\n"
" #endif\n"
" \n"
" #ifdef B_SINGLE\n"
" float data1=input1[0];\n"
" float4 in1=(float4)(data1,data1,data1,data1);\n"
" #else\n"
" float4 in1=convert_float4(vload4(0,input1+offset));\n"
" #endif\n"
" \n"
" float4 out=OPERATOR;\n"
" \n"
// activationType == 1 applies ReLU: out = max(out, 0).
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset);\n"
"#ifdef PACK_LEAVE\n"
" }\n"
"#endif\n"
" }\n"
"}\n"
// prelu_buf: pos.x is split into b = pos.x / shape.w and c = pos.x % shape.w,
// and pos.y is a flattened spatial index.
// It loads one float4 of input0 at vector index (b + c*shape.x)*(shape.y*shape.z)
// + pos.y, and the float4 of input1 at vector index c. It applies OPERATOR
// (no activation) and stores the result at the same vector index as input0.
// The indexing suggests shape = {N, H, W, channel blocks of 4}, with the
// channel block outside the batch. TODO: confirm with the host code.
"__kernel void prelu_buf(__private int global_dim0,__private int global_dim1,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape\n"
" ) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));//NC4,HW\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" int b=pos.x/shape.w;\n"
" int c=pos.x % shape.w;\n"
" int offset=(b+c*shape.x)*(shape.y*shape.z)+pos.y;\n"
" float4 in0=convert_float4(vload4(offset,input0));\n"
" float4 in1=convert_float4(vload4(pos.x % shape.w,input1));\n"
" float4 out=OPERATOR;\n"
" vstore4(CONVERT_OUTPUT4(out),offset,output);\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source (stored as a C string) for the buffer-mode raster kernels.
// They clear a buffer, or perform strided copies with optional layout conversion.
// INPUT_TYPE, OUTPUT_TYPE, OUTPUT_TYPE4, CONVERT_OUTPUT4, INPUT_FORMAT and
// OUTPUT_FORMAT are not defined in this text. Presumably the host supplies
// them as build options. TODO: confirm against the code that compiles it.
const char* raster_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_*_DIMS add leading kernel parameters that carry the intended
// launch extent. DEAL_NON_UNIFORM_DIM* return early for work-items outside it.
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// buffer_set_zero: writes 0 to output[y*global_size_dim0 + x] for every
// in-range (x, y), i.e. clears global_size_dim0*global_size_dim1 elements.
"__kernel void buffer_set_zero(\n"
" GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" \n"
" output[y*global_size_dim0+x]=(OUTPUT_TYPE)(0.0f);\n"
"}\n"
// Layout codes compared against INPUT_FORMAT / OUTPUT_FORMAT below.
"#define MNN_DATA_FORMAT_NCHW 0\n"
"#define MNN_DATA_FORMAT_NHWC 1\n"
"#define MNN_DATA_FORMAT_NC4HW4 2\n"
// raster_direct_buffer: copies one element per work-item.
// get_global_id(0) packs idx = id*size_x + x, where `id` selects one of several
// copy regions placed combineSrcOffset / combineDstOffset elements apart.
// On each side the logical index is offset + id*combine + z*stride0 +
// y*stride1 + x*stride2. The element is converted through an OUTPUT_TYPE cast.
"__kernel void raster_direct_buffer(\n"
" GLOBAL_SIZE_3_DIMS\n"
" __private const int size_x,\n"
" __global INPUT_TYPE *input,\n"
" __private const int inputOffset,\n"
" __private const int combineSrcOffset,\n"
" __private const int inputStride0,\n"
" __private const int inputStride1,\n"
" __private const int inputStride2,\n"
" __private const int src_width,\n"
" __private const int src_height,\n"
" __private const int src_channel,\n"
" __private const int src_batch,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int outputOffset,\n"
" __private const int combineDstOffset,\n"
" __private const int outputStride0,\n"
" __private const int outputStride1,\n"
" __private const int outputStride2,\n"
" __private const int dst_width,\n"
" __private const int dst_height,\n"
" __private const int dst_channel,\n"
" __private const int dst_batch\n"
" ) {\n"
" const int idx=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" const int z=get_global_id(2);\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(idx,y,z);\n"
" const int x=idx % size_x;\n"
" const int id=idx/size_x;\n"
" \n"
" int inputIndex=inputOffset+id*combineSrcOffset+z*inputStride0+y*inputStride1+x*inputStride2;\n"
" int outputIndex=outputOffset+id*combineDstOffset+z*outputStride0+y*outputStride1+x*outputStride2;\n"
// NCHW / NHWC: the logical index is already the physical index.
// NC4HW4: split the logical index into (b, c, h, w) with w fastest (NCHW
// order). Then address the channel-blocked buffer, in which the 4-channel
// block is outermost, followed by batch, H, W and the channel within the block.
"#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int inputIndexReal=inputIndex;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int inputIndexReal=inputIndex;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int in_w=inputIndex % src_width; inputIndex /= src_width;\n"
" int in_h=inputIndex % src_height; inputIndex /= src_height;\n"
" int in_c=inputIndex % src_channel;\n"
" int in_b=inputIndex/src_channel;\n"
" int inputIndexReal=(((in_b+(in_c/4)*src_batch)*src_height+in_h)*src_width+in_w)*4+(in_c % 4);\n"
"#endif\n"
" \n"
// Same mapping on the destination side, driven by OUTPUT_FORMAT.
"#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int outputIndexReal=outputIndex;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int outputIndexReal=outputIndex;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int out_w=outputIndex % dst_width; outputIndex /= dst_width;\n"
" int out_h=outputIndex % dst_height; outputIndex /= dst_height;\n"
" int out_c=outputIndex % dst_channel;\n"
" int out_b=outputIndex/dst_channel;\n"
" int outputIndexReal=(((out_b+(out_c/4)*dst_batch)*dst_height+out_h)*dst_width+out_w)*4+(out_c % 4);\n"
"#endif\n"
" output[outputIndexReal]=(OUTPUT_TYPE)input[inputIndexReal];\n"
"}\n"
// raster_nc4hw4_buffer: copies one 4-element vector per work-item.
// The stride sum is multiplied by 4, so the strides count 4-element vectors.
// inputOffset / outputOffset are added unscaled, i.e. in elements.
// The vector is converted with CONVERT_OUTPUT4.
"__kernel void raster_nc4hw4_buffer(\n"
" GLOBAL_SIZE_3_DIMS\n"
" __global INPUT_TYPE *input,\n"
" __private const int inputOffset,\n"
" __private const int inputStride0,\n"
" __private const int inputStride1,\n"
" __private const int inputStride2,\n"
" __private const int inputHeight,\n"
" __private const int inputWidth,\n"
" __private const int inputChannel,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int outputOffset,\n"
" __private const int outputStride0,\n"
" __private const int outputStride1,\n"
" __private const int outputStride2,\n"
" __private const int outputHeight,\n"
" __private const int outputWidth,\n"
" __private const int outputChannel\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" const int z=get_global_id(2);\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" int inputIndex=inputOffset+(z*inputStride0+y*inputStride1+x*inputStride2)*4;\n"
" int outputIndex=outputOffset+(z*outputStride0+y*outputStride1+x*outputStride2)*4;\n"
" \n"
" OUTPUT_TYPE4 values=CONVERT_OUTPUT4(vload4(0,(__global INPUT_TYPE *)(input+inputIndex)));\n"
" vstore4(values,0,(__global OUTPUT_TYPE *)(output+outputIndex));\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
// OpenCL C source (stored as a C string) for binary and PReLU kernels.
// They convert between the C4 buffer layout and the Intel-subgroup C16 layout.
// Kernel names encode the layouts as:
//   binary_buf_<input0>_<input1>_<output>
//   prelu_buf_<input0>_<output>
// The two layouts, as the offset expressions below compute them:
//   c4 : (((b + c4_block*N)*H + h)*W + w)*4 + c%4
//        (the channel block is outside the batch; no W padding)
//   c16: (((b*C16 + c16_block)*H + h)*Wp + w + pad_left)*16 + c%16
//        (Wp = W + pad_left + pad_right of that tensor)
// shape.w is used as the channel count: channel4 = (shape.w+3)/4 and
// channel16 = (shape.w+15)/16.
// isFull.x / isFull.y == 0 broadcasts the first element of that input.
// activationType == 1 applies ReLU to the result.
// OPERATOR, INPUT_TYPE, OUTPUT_TYPE, CONVERT_OUTPUT4 and the INTEL_* / AS_*
// helpers are not defined here. Presumably they come from build options.
const char* binary_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define PI 3.141592653589f\n"
// binary_buf_c4_c4_c4: all tensors are C4. gid0 = h*W + w, gid1 = C4 block,
// gid2 = batch. Each work-item handles one float4.
// In broadcast mode vload4 still reads 4 elements at offset 0. This presumably
// relies on the broadcast buffer holding at least 4 elements.
"__kernel void binary_buf_c4_c4_c4(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int offset=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" \n"
" float4 in0=convert_float4(vload4(0,input0+offset*isFull.x));\n"
" float4 in1=convert_float4(vload4(0,input1+offset*isFull.y));\n"
" if(isFull.x == 0) {\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" }\n"
" \n"
" float4 out=OPERATOR;\n"
" \n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset);\n"
"}\n"
// binary_buf_c4_c4_c16: both inputs are C4 and the output is C16 with W padding.
// The 4 channels of C4 block channel_idx go to 16-channel block
// channel_idx >> 2, at lane offset (channel_idx % 4)*4.
// The work-item with w_idx == 0 also zeroes its row's left and right output
// padding columns for those 4 channels.
"__kernel void binary_buf_c4_c4_c16(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int dst_width=shape.z+output_pad_left+output_pad_right;\n"
" const int channe_out_idx=channel_idx >> 2;\n"
" const int offset=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int dst_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n"
" \n"
" float4 in0=convert_float4(vload4(0,input0+offset*isFull.x));\n"
" float4 in1=convert_float4(vload4(0,input1+offset*isFull.y));\n"
" if(isFull.x == 0) {\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" }\n"
" float4 out=OPERATOR;\n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+dst_offset);\n"
" if(w_idx == 0){\n"
" int pad_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width)*16+(channel_idx % 4)*4;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" pad_offset += (shape.z+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
// binary_buf_c4_c16_c4: input0 is C4, input1 is C16 (padded by input1_pad_*),
// and the output is C4 at input0's offset.
"__kernel void binary_buf_c4_c16_c4(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int src_width=shape.z+input1_pad_left+input1_pad_right;\n"
" const int channe_out_idx=channel_idx >> 2;\n"
" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int offset1=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input1_pad_left)*16+(channel_idx % 4)*4;\n"
" float4 in0=convert_float4(vload4(0,input0+offset0*isFull.x));\n"
" float4 in1=convert_float4(vload4(0,input1+offset1*isFull.y));\n"
" if(isFull.x == 0) {\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" }\n"
" float4 out=OPERATOR;\n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset0);\n"
"}\n"
// binary_buf_c16_c4_c4: input0 is C16 (padded by input0_pad_*), input1 is C4,
// and the output is C4 at input1's offset.
"__kernel void binary_buf_c16_c4_c4(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int src_width=shape.z+input0_pad_left+input0_pad_right;\n"
" const int channe_out_idx=channel_idx >> 2;\n"
" const int offset1=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int offset0=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16+(channel_idx % 4)*4;\n"
" \n"
" float4 in0=convert_float4(vload4(0,input0+offset0*isFull.x));\n"
" float4 in1=convert_float4(vload4(0,input1+offset1*isFull.y));\n"
" if(isFull.x == 0) {\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" }\n"
" float4 out=OPERATOR;\n"
" \n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset1);\n"
"}\n"
// binary_buf_c4_c16_c16: input0 is C4, input1 is C16, and the output is C16
// with W padding. The w_idx == 0 work-item zeroes the output padding columns.
"__kernel void binary_buf_c4_c16_c16(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int src_width=shape.z+input1_pad_left+input1_pad_right;\n"
" const int dst_width=shape.z+output_pad_left+output_pad_right;\n"
" const int channe_out_idx=channel_idx >> 2;\n"
" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int offset1=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input1_pad_left)*16+(channel_idx % 4)*4;\n"
" const int dst_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n"
" \n"
" float4 in0=convert_float4(vload4(0,input0+offset0*isFull.x));\n"
" float4 in1=convert_float4(vload4(0,input1+offset1*isFull.y));\n"
" if(isFull.x == 0) {\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" }\n"
" float4 out=OPERATOR;\n"
" \n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+dst_offset);\n"
" if(w_idx == 0){\n"
" int pad_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width)*16+(channel_idx % 4)*4;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" pad_offset += (shape.z+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
// binary_buf_c16_c4_c16: input0 is C16, input1 is C4, and the output is C16
// with W padding. The w_idx == 0 work-item zeroes the output padding columns.
"__kernel void binary_buf_c16_c4_c16(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int src_width=shape.z+input0_pad_left+input0_pad_right;\n"
" const int dst_width=shape.z+output_pad_left+output_pad_right;\n"
" const int channe_out_idx=channel_idx >> 2;\n"
" const int offset1=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int offset0=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16+(channel_idx % 4)*4;\n"
" const int dst_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n"
" \n"
" float4 in0=convert_float4(vload4(0,input0+offset0*isFull.x));\n"
" float4 in1=convert_float4(vload4(0,input1+offset1*isFull.y));\n"
" if(isFull.x == 0) {\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" }\n"
" float4 out=OPERATOR;\n"
" \n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+dst_offset);\n"
" if(w_idx == 0){\n"
" int pad_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width)*16+(channel_idx % 4)*4;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" pad_offset += (shape.z+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
// prelu_buf_c4_c4: input0 and the output are C4. The float4 slope for C4
// block channel_idx is read from input1 at element channel_idx*4.
// No activation is applied.
"__kernel void prelu_buf_c4_c4(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right\n"
" ) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" \n"
" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int offset1=channel_idx*4;\n"
" \n"
" float4 in0=convert_float4(vload4(0,input0+offset0));\n"
" float4 in1=convert_float4(vload4(0,input1+offset1));\n"
" float4 out=OPERATOR;\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset0);\n"
"}\n"
// prelu_buf_c4_c16: input0 is C4 and the output is C16 with W padding.
// The slope is read as in prelu_buf_c4_c4. The w_idx == 0 work-item zeroes the
// output padding columns for its 4 channels.
"__kernel void prelu_buf_c4_c16(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right\n"
" ) {\n"
" if (get_global_id(0) >= global_dim0 || get_global_id(1) >= global_dim1 || get_global_id(2) >= global_dim2) \n"
" return;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int w_idx=get_global_id(0) % shape.z;\n"
" const int h_idx=get_global_id(0)/shape.z;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_global_id(1);\n"
" const int dst_width=shape.z+output_pad_left+output_pad_right;\n"
" const int channe_out_idx=channel_idx >> 2;\n"
" \n"
" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" const int offset1=channel_idx*4;\n"
" const int offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n"
" float4 in0=convert_float4(vload4(0,input0+offset0));\n"
" float4 in1=convert_float4(vload4(0,input1+offset1));\n"
" float4 out=OPERATOR;\n"
" \n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset);\n"
" if(w_idx == 0){\n"
" int pad_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width)*16+(channel_idx % 4)*4;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" pad_offset += (shape.z+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
// prelu_buf_c16_c16: subgroup kernel with a required subgroup size of 16.
// The lane index sglid is the channel within the 16-channel block.
// The block comes from get_group_id(1), which presumably relies on a local
// size of 16 along dim 1 -- confirm with the host.
// Each work-item covers 4 consecutive W positions: w_idx = (gid0 % width_pack)*4
// and h = gid0 / width_pack. The per-lane slope is read once and splatted.
// In the W tail (w_idx+4 > W) only W % 4 positions are stored, element-wise.
// global_dim0..2 are not checked here, so the grid is presumably exact.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void prelu_buf_c16_c16(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel16=(shape.w+15)/16;\n"
" const int width_pack=(shape.z+3)/4;\n"
" const int w_idx=(get_global_id(0) % width_pack) << 2;\n"
" const int h_idx=get_global_id(0)/width_pack;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_group_id(1);\n"
" const int sglid=get_sub_group_local_id();\n"
" const int src_width=shape.z+input0_pad_left+input0_pad_right;\n"
" const int dst_width=shape.z+output_pad_left+output_pad_right;\n"
" const int offset0=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16;\n"
" const int offset1=channel_idx*16;\n"
" const int offset=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16;\n"
" float4 in0=convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0+offset0))));\n"
" float4 in1=(float4)(AS_INPUT_DATA(INTEL_SUB_GROUP_READ((__global INTEL_DATA*)(input1+offset1))));\n"
" \n"
" float4 out=OPERATOR;\n"
" {\n"
" if (w_idx+4>shape.z) {\n"
" for (int i=0; i<shape.z % 4; i++) {\n"
" output[offset+i*16+sglid]=(OUTPUT_TYPE)out[i];\n"
" }\n"
" }else{\n"
" INTEL_SUB_GROUP_WRITE4((__global INTEL_DATA*)(output+offset),AS_OUTPUT_DATA4(CONVERT_OUTPUT4(out)));\n"
" }\n"
" }\n"
" if(w_idx == 0){\n"
" int pad_offset=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*dst_width)*16+sglid;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*16]=(OUTPUT_TYPE)0;\n"
" }\n"
" pad_offset += (shape.z+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*16]=(OUTPUT_TYPE)0;\n"
" }\n"
" }\n"
"}\n"
// prelu_buf_c16_c4: input0 is C16 and the output is C4.
// Lane sglid holds channel channel_idx*16 + sglid. It is scattered to C4 block
// (channel_idx << 2) + sglid/4 at position sglid % 4; C4 blocks are
// batch_width_height elements apart.
// Fix: when the channel count is not a multiple of 16, the last C16 block maps
// some lanes to C4 blocks >= channel4. For example, 20 channels give 5 C4
// blocks, but block 1's lanes map to C4 blocks 4..7. Those stores would land
// past the end of an output sized for channel4 blocks, so they are now skipped.
// The guard sits after all subgroup reads, so every lane still takes part in them.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void prelu_buf_c16_c4(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C]\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel4=(shape.w+3)/4;\n"
" const int channel16=(shape.w+15)/16;\n"
" const int width_pack=(shape.z+3)/4;\n"
" const int w_idx=(get_global_id(0) % width_pack) << 2;\n"
" const int h_idx=get_global_id(0)/width_pack;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_group_id(1);\n"
" const int sglid=get_sub_group_local_id();\n"
" const int src_width=shape.z+input0_pad_left+input0_pad_right;\n"
" const int batch_width_height=shape.x*shape.z*shape.y*4;\n"
" const int offset0=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16;\n"
" const int offset1=channel_idx*16;\n"
" const int offset=(((batch_idx+(channel_idx<<2)*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" float4 in0=convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0+offset0))));\n"
" float4 in1=(float4)(AS_INPUT_DATA(INTEL_SUB_GROUP_READ((__global INTEL_DATA*)(input1+offset1))));\n"
" \n"
" float4 out=OPERATOR;\n"
" const int lid_x=sglid % 4;\n"
" const int lid_y=sglid/4;\n"
" int block_size=w_idx+4>shape.z ? (shape.z % 4) : 4;\n"
" if ((channel_idx << 2)+lid_y<channel4) {\n"
" for (int i=0; i<block_size; i++) {\n"
" output[offset+i*4+lid_y*batch_width_height+lid_x]=(OUTPUT_TYPE)out[i];\n"
" }\n"
" }\n"
"}\n"
// binary_buf_c16_c16_c16: all tensors are C16 with per-tensor W padding.
// The work-item mapping is the same as in prelu_buf_c16_c16.
// In broadcast mode an input is splatted from its element 0.
// The in-kernel comment says [N,H,W,C4], but channel16 = (shape.w+15)/16
// treats shape.w as the channel count.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void binary_buf_c16_c16_c16(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C4]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel16=(shape.w+15)/16;\n"
" const int width_pack=(shape.z+3)/4;\n"
" const int w_idx=(get_global_id(0) % width_pack) << 2;\n"
" const int h_idx=get_global_id(0)/width_pack;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_group_id(1);\n"
" const int sglid=get_sub_group_local_id();\n"
" const int src0_width=shape.z+input0_pad_left+input0_pad_right;\n"
" const int src1_width=shape.z+input1_pad_left+input1_pad_right;\n"
" const int dst_width=shape.z+output_pad_left+output_pad_right;\n"
" const int offset0=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src0_width+w_idx+input0_pad_left)*16;\n"
" const int offset1=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src1_width+w_idx+input1_pad_left)*16;\n"
" const int offset=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16;\n"
" float4 in0=isFull.x ? convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0+offset0)))) : (float4)(input0[0]);\n"
" float4 in1=isFull.y ? convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input1+offset1)))) : (float4)(input1[0]);\n"
" \n"
" float4 out=OPERATOR;\n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" {\n"
" if (w_idx+4>shape.z) {\n"
" for (int i=0; i<shape.z % 4; i++) {\n"
" output[offset+i*16+sglid]=(OUTPUT_TYPE)out[i];\n"
" }\n"
" }else{\n"
" INTEL_SUB_GROUP_WRITE4((__global INTEL_DATA*)(output+offset),AS_OUTPUT_DATA4(CONVERT_OUTPUT4(out)));\n"
" }\n"
" }\n"
" if(w_idx == 0){\n"
" int pad_offset=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*dst_width)*16+sglid;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*16]=(OUTPUT_TYPE)0;\n"
" }\n"
" pad_offset += (shape.z+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*16]=(OUTPUT_TYPE)0;\n"
" }\n"
" }\n"
"}\n"
// binary_buf_c16_c16_c4: both inputs are C16 and the output is C4. Lanes are
// scattered to C4 blocks as in prelu_buf_c16_c4.
// Fix: stores to C4 blocks >= channel4 are skipped, because they would fall
// past the end of a channel4-sized output when the channel count is not a
// multiple of 16. The guard follows the subgroup reads, so they stay uniform.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void binary_buf_c16_c16_c4(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input0,__global INPUT_TYPE* input1,__global OUTPUT_TYPE* output,\n"
" __private const int4 shape,//[N,H,W,C4]\n"
" __private const int2 isFull,\n"
" __private const int activationType,\n"
" __private const int input0_pad_left,__private const int input0_pad_right,\n"
" __private const int input1_pad_left,__private const int input1_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel16=(shape.w+15)/16;\n"
" const int channel4=(shape.w+3)/4;\n"
" const int width_pack=(shape.z+3)/4;\n"
" const int w_idx=(get_global_id(0) % width_pack) << 2;\n"
" const int h_idx=get_global_id(0)/width_pack;\n"
" const int batch_idx=get_global_id(2);\n"
" const int channel_idx=get_group_id(1);\n"
" const int sglid=get_sub_group_local_id();\n"
" const int src0_width=shape.z+input0_pad_left+input0_pad_right;\n"
" const int src1_width=shape.z+input1_pad_left+input1_pad_right;\n"
" const int batch_width_height=shape.x*shape.z*shape.y*4;\n"
" const int offset0=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src0_width+w_idx+input0_pad_left)*16;\n"
" const int offset1=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src1_width+w_idx+input1_pad_left)*16;\n"
" const int offset=(((batch_idx+(channel_idx << 2)*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n"
" float4 in0=isFull.x ? convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0+offset0)))) : (float4)(input0[0]);\n"
" float4 in1=isFull.y ? convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input1+offset1)))) : (float4)(input1[0]);\n"
" \n"
" float4 out=OPERATOR;\n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" const int lid_x=sglid % 4;\n"
" const int lid_y=sglid/4;\n"
" int block_size=w_idx+4>shape.z ? (shape.z % 4) : 4;\n"
" if ((channel_idx << 2)+lid_y<channel4) {\n"
" for (int i=0; i<block_size; i++) {\n"
" output[offset+i*4+lid_y*batch_width_height+lid_x]=(OUTPUT_TYPE)out[i];\n"
" }\n"
" }\n"
"}\n"
;
#endif
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
// OpenCL source: cl_intel_subgroups (SIMD16) depthwise-convolution kernels over the
// C16-packed buffer layout (one sub-group covers 16 channels, one channel per lane).
// Each work-item produces 8 consecutive output columns of one output row.
//   depthwise_conv_2d_buf_c16_c16 : C16-packed output; work-items with x == 0 also
//                                   zero the output_pad_left/right columns.
//   depthwise_conv_2d_buf_c16_c4  : C4-packed output, element offset
//                                   (((c4*Batch + b)*outputHeight + y)*outputWidth + x)*4 + channel%4.
// Host-supplied build macros referenced here: FILTER_WIDTH/HEIGHT, STRIDE_WIDTH/HEIGHT,
// DILATION_WIDTH/HEIGHT, FLOAT, COMPUTE_FLOAT/COMPUTE_FLOAT8, optional RELU / RELU6 and
// MNN_SUPPORT_FP16.
// Only input rows are bounds-checked; horizontal taps read directly from the input row,
// so the kernels appear to rely on the input_pad_left/right columns existing (and
// presumably being zeroed) in the buffer -- NOTE(review): confirm against the host code.
// The FP32 path performs 32-bit sub-group block reads, so every pointer handed to
// intel_sub_group_block_read is cast to __global uint* (its declared argument type).
const char* depthwise_conv2d_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void depthwise_conv_2d_buf_c16_c16(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int inputHeight,\n"
" __private const int inputWidth,\n"
" __private const int Channel,\n"
" __private const int Batch,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int outputHeight,\n"
" __private const int outputWidth,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right,\n"
" __private const int pad_w,\n"
" __private const int pad_h\n"
") {\n"
" const int x_blocks=(outputWidth+7)/8;\n"
" const int sglid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks)*8;\n"
" const int y=(xy/x_blocks);\n"
" const int c=get_group_id(1);\n"
" const int input_x=x*STRIDE_WIDTH-pad_w;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_h;\n"
" const int channel_pack=((Channel+15)/16);\n"
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(inputWidth+input_pad_left+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(inputHeight);\n"
" const uint input_b_pitch=input_fs_pitch*channel_pack;\n"
" const uint input_offset=b*input_b_pitch +\n"
" c*input_fs_pitch+\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(outputWidth+output_pad_left+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*outputHeight;\n"
" const uint output_b_pitch=output_fs_pitch*channel_pack;\n"
" const uint output_offset=b*output_b_pitch +\n"
" c*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
" const uint filter_x_pitch=16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_offset=c*filter_is_pitch;\n"
"#ifdef MNN_SUPPORT_FP16\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(as_half(intel_sub_group_block_read_us((__global ushort*)(biases+c*16))));\n"
" for(int i=0; i<FILTER_HEIGHT; ++i){\n"
" if ((input_y+i*DILATION_HEIGHT)<0 || (input_y+i*DILATION_HEIGHT) >= inputHeight)\n"
" continue;\n"
" for(int j=0; j<FILTER_WIDTH; ++j){\n"
" COMPUTE_FLOAT wei=as_half(intel_sub_group_block_read_us((__global ushort*)(weights+filter_offset+i*filter_y_pitch+j*filter_x_pitch)));\n"
" for(int k=0; k<8; ++k){\n"
" COMPUTE_FLOAT src=as_half(intel_sub_group_block_read_us((__global ushort*)(input+input_offset+i*DILATION_HEIGHT*input_y_pitch+(j*DILATION_WIDTH+k*STRIDE_WIDTH)*input_x_pitch)));\n"
" dst[k]=mad(src,wei,dst[k]);\n"
" }\n"
" }\n"
" }\n"
" \n"
"#else\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(as_float(intel_sub_group_block_read((__global uint*)(biases+c*16))));\n"
" for(int i=0; i<FILTER_HEIGHT; ++i){\n"
" if ((input_y+i*DILATION_HEIGHT)<0 || (input_y+i*DILATION_HEIGHT) >= inputHeight)\n"
" continue;\n"
" for(int j=0; j<FILTER_WIDTH; ++j){\n"
" COMPUTE_FLOAT wei=as_float(intel_sub_group_block_read((__global uint*)(weights+filter_offset+i*filter_y_pitch+j*filter_x_pitch)));\n"
" for(int k=0; k<8; ++k){\n"
" COMPUTE_FLOAT src=as_float(intel_sub_group_block_read((__global uint*)(input+input_offset+i*DILATION_HEIGHT*input_y_pitch+(j*DILATION_WIDTH+k*STRIDE_WIDTH)*input_x_pitch)));\n"
" dst[k]=mad(src,wei,dst[k]);\n"
" }\n"
" }\n"
" }\n"
"#endif\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
" \n"
" for (int i=0; i<8 && (x+i)<outputWidth; i++) {\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us((__global ushort*)(output+output_offset+i*output_x_pitch),as_ushort((FLOAT)dst[i]));\n"
"#else\n"
" intel_sub_group_block_write((__global uint*)(output+output_offset+i*output_x_pitch),as_uint((FLOAT)dst[i]));\n"
"#endif\n"
" }\n"
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+c*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" pad_offset += (outputWidth+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void depthwise_conv_2d_buf_c16_c4(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int inputHeight,\n"
" __private const int inputWidth,\n"
" __private const int Channel,\n"
" __private const int Batch,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int outputHeight,\n"
" __private const int outputWidth,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right,\n"
" __private const int pad_w,\n"
" __private const int pad_h\n"
") {\n"
" const int x_blocks=(outputWidth+7)/8;\n"
" const int sglid=get_sub_group_local_id();\n"
" const int b=get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks)*8;\n"
" const int y=(xy/x_blocks);\n"
" const int c=get_group_id(1);\n"
" const int input_x=x*STRIDE_WIDTH-pad_w;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_h;\n"
" const int channel_pack=((Channel+15)/16);\n"
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(inputWidth+input_pad_left+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(inputHeight);\n"
" const uint input_b_pitch=input_fs_pitch*channel_pack;\n"
" const uint input_offset=b*input_b_pitch +\n"
" c*input_fs_pitch+\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*outputWidth;\n"
" const uint output_fs_pitch=output_y_pitch*outputHeight;\n"
" const uint output_b_pitch=output_fs_pitch*Batch;\n"
" const uint output_offset=(c << 2)*output_b_pitch +\n"
" b*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
" const uint filter_x_pitch=16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_offset=c*filter_is_pitch;\n"
"#ifdef MNN_SUPPORT_FP16\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(as_half(intel_sub_group_block_read_us((__global ushort*)(biases+c*16))));\n"
" for(int i=0; i<FILTER_HEIGHT; ++i){\n"
" if ((input_y+i*DILATION_HEIGHT)<0 || (input_y+i*DILATION_HEIGHT) >= inputHeight)\n"
" continue;\n"
" for(int j=0; j<FILTER_WIDTH; ++j){\n"
" COMPUTE_FLOAT wei=as_half(intel_sub_group_block_read_us((__global ushort*)(weights+filter_offset+i*filter_y_pitch+j*filter_x_pitch)));\n"
" for(int k=0; k<8; ++k){\n"
" COMPUTE_FLOAT src=as_half(intel_sub_group_block_read_us((__global ushort*)(input+input_offset+i*DILATION_HEIGHT*input_y_pitch+(j*DILATION_WIDTH+k*STRIDE_WIDTH)*input_x_pitch)));\n"
" dst[k]=mad(src,wei,dst[k]);\n"
" }\n"
" }\n"
" }\n"
" \n"
"#else\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(as_float(intel_sub_group_block_read((__global uint*)(biases+c*16))));\n"
" for(int i=0; i<FILTER_HEIGHT; ++i){\n"
" if ((input_y+i*DILATION_HEIGHT)<0 || (input_y+i*DILATION_HEIGHT) >= inputHeight)\n"
" continue;\n"
" for(int j=0; j<FILTER_WIDTH; ++j){\n"
" COMPUTE_FLOAT wei=as_float(intel_sub_group_block_read((__global uint*)(weights+filter_offset+i*filter_y_pitch+j*filter_x_pitch)));\n"
" for(int k=0; k<8; ++k){\n"
" COMPUTE_FLOAT src=as_float(intel_sub_group_block_read((__global uint*)(input+input_offset+i*DILATION_HEIGHT*input_y_pitch+(j*DILATION_WIDTH+k*STRIDE_WIDTH)*input_x_pitch)));\n"
" dst[k]=mad(src,wei,dst[k]);\n"
" }\n"
" }\n"
" }\n"
"#endif\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
" const uint lid_x=sglid % 4;\n"
" const uint lid_y=sglid/4;\n"
" for (int i=0; i<8 && (x+i)<outputWidth; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=dst[i];\n"
" }\n"
"}\n"
;
#endif
#endif
const char* nearest = 
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void interp(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const float height_scale,__private const float width_scale,\n"
" __private const float height_offset,__private const float width_offset,\n"
" __private const int input_height,__private const int input_width,\n"
" __private const int out_height) {\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_channel_block_idxs=global_size_dim0;\n"
" const int output_width=global_size_dim1;\n"
" const int output_batch_idx=output_batch_height_block_idx/out_height;\n"
" const int output_height_idx=output_batch_height_block_idx % out_height;\n"
" const float scale_height=output_height_idx*height_scale+height_offset;\n"
" const float scale_width=output_width_block_idx*width_scale+width_offset;\n"
"#ifdef USE_ROUND\n"
" const int height_lf=min(max(0,(int)floor(scale_height+0.499f)),input_height-1);\n"
" const int width_lf=min(max(0,(int)floor(scale_width+0.499f)),input_width-1);\n"
"#else\n"
" const int height_lf=min(max(0,(int)floor(scale_height)),input_height-1);\n"
" const int width_lf=min(max(0,(int)floor(scale_width)),input_width-1);\n"
"#endif\n"
" const int input_width_offset=mul24(output_channel_block_idx,input_width);\n"
" const int input_height_offset=mul24(output_batch_idx,input_height);\n"
" float4 out =\n"
" read_imagef(input,SAMPLER,(int2)(input_width_offset+width_lf,input_height_offset+height_lf));\n"
" const int out_image_w=mad24(output_channel_block_idx,output_width,output_width_block_idx);\n"
" const int out_image_h=mad24(output_batch_idx,out_height,output_height_idx);\n"
" write_imagef(output,(int2)(out_image_w,out_image_h),out);\n"
"}\n"
"__kernel void interp3D(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const float depth_scale,__private const float height_scale,__private const float width_scale,\n"
" __private const float depth_offset,__private const float height_offset,__private const float width_offset,\n"
" __private const int input_depth,__private const int input_height,__private const int input_width,\n"
" __private const int out_depth,__private const int out_height) {\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_height_width_block_idx=get_global_id(1);\n"
" const int output_batch_depth_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_height_width_block_idx,output_batch_depth_block_idx);\n"
" const int output_channel_block_idxs=global_size_dim0;\n"
" const int output_tensor_height_width=global_size_dim1;\n"
" const int out_width=output_tensor_height_width/out_height;\n"
" const int output_batch_idx=output_batch_depth_block_idx/out_depth;\n"
" const int output_depth_idx=output_batch_depth_block_idx % out_depth;\n"
" const int output_height_idx=output_height_width_block_idx/out_height;\n"
" const int output_width_idx=output_height_width_block_idx % out_height;\n"
" const float scale_depth=output_depth_idx*depth_scale+depth_offset;\n"
" const float scale_height=output_height_idx*height_scale+height_offset;\n"
" const float scale_width=output_width_idx*width_scale+width_offset;\n"
" const int depth_lf=max(0,(int)floor(scale_depth));\n"
" const int height_lf=max(0,(int)floor(scale_height));\n"
" const int width_lf=max(0,(int)floor(scale_width));\n"
" const int input_tensor_width_height=mul24(input_width,input_height);\n"
" const int input_image_width_offset=mul24(output_channel_block_idx,input_tensor_width_height);\n"
" const int input_image_height_offset=mul24(output_batch_idx,input_depth);\n"
" float4 out=read_imagef(input,SAMPLER,\n"
" (int2)(input_image_width_offset+input_width*(height_offset+height_lf)+width_lf+width_offset,\n"
" input_image_height_offset+depth_lf+depth_offset));\n"
" const int output_image_width_offset=output_channel_block_idx*output_tensor_height_width;\n"
" const int output_image_height_offset=output_batch_idx*out_depth;\n"
" // TODO: out\n"
" const int out_image_w=output_image_width_offset+output_height_idx*out_width+output_width_idx;\n"
" const int out_image_h=output_image_height_offset+output_batch_idx*out_depth+output_depth_idx;\n"
" write_imagef(output,(int2)(out_image_w,out_image_h),out);\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
// OpenCL source: pooling kernels for the Intel sub-group buffer path. Max pooling by
// default; average pooling with POOL_AVG (COUNT_INCLUDE_PADDING makes the divisor count
// taps inside the padded extent). With RETURN_REDICE, max pooling also stores the flat
// input index h*input_width + w of the selected element into rediceOutput.
//   pooling_c4_c4   : C4-packed input -> C4-packed output, one output pixel per work-item.
//   pooling_c4_c16  : C4-packed input -> C16-packed output; ow_idx == 0 items also zero
//                     the output_pad_left/right columns for their 4 channels.
//   pooling_c16_c16 : SIMD16 sub-group, C16-packed in/out, 8 output columns per item.
//   pooling_c16_c4  : SIMD16 sub-group, C16-packed input -> C4-packed output.
// Build macros referenced: KERNEL_X/Y, STRIDE_X/Y, FLOAT(4/8), COMPUTE_FLOAT(4/8), the
// CONVERT_* helpers, INPUT_LINE_SIZE and OUTPUT_LEFTOVERS. The c16 kernels index
// line_cache up to KERNEL_X-1 + 7*STRIDE_X, so INPUT_LINE_SIZE must be >= KERNEL_X + 7*STRIDE_X.
// Notes:
//   * averages divide by a COMPUTE_FLOAT count (no double literal; the target GPUs may lack fp64);
//   * comparisons feeding the int redice select are converted with convert_intN so the
//     select also type-checks if COMPUTE_FLOAT is half (half compares yield short vectors);
//   * the c16 kernels record each column's own input index (iw_start + STRIDE_X*i + kw);
//   * RETURN_REDICE is always tested with #if, and every leftover/tail store is bounds-guarded.
const char* pooling_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__kernel void pooling_c4_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 pad_shape,\n"
" __global FLOAT *output,\n"
" __global FLOAT *rediceOutput,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int in_channel_block,\n"
" __private const int out_channel_block,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right) {\n"
" \n"
" const int ow_idx=get_global_id(0);\n"
" const int b_oh_idx=get_global_id(1);\n"
" const int c_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(ow_idx,b_oh_idx,c_idx);\n"
" \n"
" const int b_idx=b_oh_idx/output_shape.x;\n"
" const int oh_idx=b_oh_idx % output_shape.x;\n"
" const int iw_start=mad24(ow_idx,STRIDE_X,-pad_shape.y);\n"
" const int ih_start=mad24(oh_idx,STRIDE_Y,-pad_shape.x);\n"
" \n"
" #ifdef POOL_AVG\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(0);\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n"
"#ifdef COUNT_INCLUDE_PADDING\n"
" int total_count=(min(ih_start+KERNEL_Y,input_shape.x+pad_shape.x)-ih_start)*(min(iw_start+KERNEL_X,input_shape.y+pad_shape.y)-iw_start);\n"
"#else\n"
" int total_count=0;\n"
"#endif\n"
" for(int kh=0; kh<KERNEL_Y; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" for(int kw=0; kw<KERNEL_X; kw++) {\n"
" int iw_cur=iw_start+kw;\n"
" if(iw_cur<0 || iw_cur >= input_shape.y) {\n"
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
" result += inp_data;\n"
"#ifndef COUNT_INCLUDE_PADDING\n"
" total_count++;\n"
"#endif\n"
" }\n"
" }\n"
" result=result/(COMPUTE_FLOAT4)((COMPUTE_FLOAT)total_count);\n"
" #else\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(-FLT_MAX);\n"
" #if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n"
" for(int kh=0; kh<KERNEL_Y; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" for(int kw=0; kw<KERNEL_X; kw++) {\n"
" int iw_cur=iw_start+kw;\n"
" if(iw_cur<0 || iw_cur >= input_shape.y) {\n"
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
" #if RETURN_REDICE\n"
" redice=convert_int4(inp_data>result) ? (int4)((ih_start+kh)*input_shape.y+iw_start+kw) : redice;\n"
" #endif\n"
" result=fmax(result,inp_data);\n"
" }\n"
" }\n"
" #endif\n"
" \n"
" const int out_offset=(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx+output_pad_left)*4;\n"
" vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n"
" #if RETURN_REDICE\n"
" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4);\n"
" #endif\n"
"}\n"
"__kernel void pooling_c4_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 pad_shape,\n"
" __global FLOAT *output,\n"
" __global FLOAT *rediceOutput,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int in_channel_block,\n"
" __private const int out_channel_block,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right) {\n"
" \n"
" const int ow_idx=get_global_id(0);\n"
" const int b_oh_idx=get_global_id(1);\n"
" const int c_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(ow_idx,b_oh_idx,c_idx);\n"
" \n"
" const int b_idx=b_oh_idx/output_shape.x;\n"
" const int oh_idx=b_oh_idx % output_shape.x;\n"
" const int iw_start=mad24(ow_idx,STRIDE_X,-pad_shape.y);\n"
" const int ih_start=mad24(oh_idx,STRIDE_Y,-pad_shape.x);\n"
" const int dst_width=output_shape.y+output_pad_left+output_pad_right;\n"
" \n"
" #ifdef POOL_AVG\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(0);\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n"
" #ifdef COUNT_INCLUDE_PADDING\n"
" int total_count=(min(ih_start+KERNEL_Y,input_shape.x+pad_shape.x)-ih_start)*(min(iw_start+KERNEL_X,input_shape.y+pad_shape.y)-iw_start);\n"
"#else\n"
" int total_count=0;\n"
"#endif\n"
" for(int kh=0; kh<KERNEL_Y; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" for(int kw=0; kw<KERNEL_X; kw++) {\n"
" int iw_cur=iw_start+kw;\n"
" if(iw_cur<0 || iw_cur >= input_shape.y) {\n"
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
" result += inp_data;\n"
"#ifndef COUNT_INCLUDE_PADDING\n"
" total_count++;\n"
"#endif\n"
" }\n"
" }\n"
" result=result/(COMPUTE_FLOAT4)((COMPUTE_FLOAT)total_count);\n"
" #else\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(-FLT_MAX);\n"
" #if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n"
" for(int kh=0; kh<KERNEL_Y; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" for(int kw=0; kw<KERNEL_X; kw++) {\n"
" int iw_cur=iw_start+kw;\n"
" if(iw_cur<0 || iw_cur >= input_shape.y) {\n"
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
" #if RETURN_REDICE\n"
" redice=convert_int4(inp_data>result) ? (int4)((ih_start+kh)*input_shape.y+iw_start+kw) : redice;\n"
" #endif\n"
" result=fmax(result,inp_data);\n"
" }\n"
" }\n"
" #endif\n"
" const int c_left=(c_idx % 4)*4;\n"
" const int out_offset=(((b_idx*out_channel_block+c_idx/4)*output_shape.x+oh_idx)* dst_width+ow_idx+output_pad_left)*16+c_left;\n"
" vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n"
" #if RETURN_REDICE\n"
" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+(((b_idx*out_channel_block+c_idx)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4);\n"
" #endif\n"
" if(ow_idx == 0){\n"
" int pad_offset=(((b_idx*out_channel_block+c_idx/4)*output_shape.x+oh_idx)* dst_width+0)*16+c_left;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore4((FLOAT4)0,0,output+pad_offset+i*16);\n"
" }\n"
" pad_offset += (output_shape.y+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore4((FLOAT4)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void pooling_c16_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 pad_shape,\n"
" __global FLOAT *output,\n"
" __global FLOAT *rediceOutput,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int in_channel_block,\n"
" __private const int out_channel_block,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right) {\n"
" \n"
" const int ow_idx=get_global_id(1) << 3;\n"
" const int b_oh_idx=get_global_id(2);\n"
" const int c_idx=get_group_id(0);\n"
" const int sglid=get_sub_group_local_id();\n"
" \n"
" const int b_idx=b_oh_idx/output_shape.x;\n"
" const int oh_idx=b_oh_idx % output_shape.x;\n"
" const int iw_start=mad24(ow_idx,STRIDE_X,-pad_shape.y);\n"
" const int ih_start=mad24(oh_idx,STRIDE_Y,-pad_shape.x);\n"
" const int src_width=input_shape.y+input_pad_left+input_pad_right;\n"
" const int dst_width=output_shape.y+output_pad_left+output_pad_right;\n"
"#ifdef POOL_AVG\n"
" COMPUTE_FLOAT8 result=(COMPUTE_FLOAT8)(0);\n"
" COMPUTE_FLOAT8 w_start=(COMPUTE_FLOAT8)(iw_start,iw_start+STRIDE_X,iw_start+STRIDE_X*2,iw_start+STRIDE_X*3,iw_start+STRIDE_X*4,iw_start+STRIDE_X*5,iw_start+STRIDE_X*6,iw_start+STRIDE_X*7);\n"
"#ifdef COUNT_INCLUDE_PADDING\n"
" COMPUTE_FLOAT8 w_size=fmin(w_start+KERNEL_X,(COMPUTE_FLOAT8)(input_shape.y+pad_shape.y))-w_start;\n"
" COMPUTE_FLOAT8 total_count=(COMPUTE_FLOAT8)(min(ih_start+KERNEL_Y,input_shape.x+pad_shape.x)-ih_start)*w_size;\n"
"#else\n"
" w_start=fmax(w_start,(COMPUTE_FLOAT8)0);\n"
" COMPUTE_FLOAT8 w_end=fmin(w_start+KERNEL_X,(COMPUTE_FLOAT8)input_shape.y);\n"
" float h_start=fmax((float)ih_start,0.0f);\n"
" float h_end=fmin(h_start+KERNEL_Y,(float)input_shape.x);\n"
" COMPUTE_FLOAT8 total_count=(w_end-w_start)*(COMPUTE_FLOAT8)(h_end-h_start);\n"
"#endif\n"
"#else\n"
" COMPUTE_FLOAT8 result=(COMPUTE_FLOAT8)(-FLT_MAX);\n"
"#if RETURN_REDICE\n"
" int8 redice=(int8)0;\n"
"#endif\n"
"#endif\n"
" const int inp_offset=mul24(mad24(mad24(mad24(b_idx,in_channel_block,c_idx),input_shape.x,ih_start),src_width,iw_start+input_pad_left),16);\n"
" for(int kh=0; kh<KERNEL_Y; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" for (int i=0; i<INPUT_LINE_SIZE; i++) {\n"
" if ((iw_start+i) >= 0 && (iw_start+i)<input_shape.y){\n"
"#ifdef MNN_SUPPORT_FP16\n"
" line_cache[i]=as_half(intel_sub_group_block_read_us((__global ushort*)(input+inp_offset+mul24(mad24(kh,src_width,i),16))));\n"
"#else\n"
" line_cache[i]=as_float(intel_sub_group_block_read((__global uint*)(input+inp_offset+mul24(mad24(kh,src_width,i),16))));\n"
"#endif\n"
" } else{\n"
"#ifdef POOL_AVG\n"
" line_cache[i]=0;\n"
"#else\n"
" line_cache[i]=(FLOAT)(-FLT_MAX);\n"
"#endif\n"
" }\n"
" }\n"
" for(int kw=0; kw<KERNEL_X; kw++) {\n"
" COMPUTE_FLOAT8 src;\n"
" __attribute__((opencl_unroll_hint(8)))\n"
" for (int i=0; i<8; i++) {\n"
" src[i]=line_cache[kw+STRIDE_X*i];\n"
" }\n"
"#ifdef POOL_AVG\n"
" result += src;\n"
"#else\n"
"#if RETURN_REDICE\n"
" redice=convert_int8(src>result) ? ((int8)((ih_start+kh)*input_shape.y+iw_start+kw)+(int8)(0,1,2,3,4,5,6,7)*STRIDE_X) : redice;\n"
"#endif\n"
" result=fmax(result,src);\n"
"#endif\n"
" }\n"
" }\n"
"#ifdef POOL_AVG\n"
" result=result/total_count;\n"
"#endif\n"
" if(ow_idx == 0){\n"
" int pad_offset=(((b_idx*out_channel_block+c_idx)*output_shape.x+oh_idx)* dst_width+0)*16+sglid;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*16]=0;\n"
" }\n"
" pad_offset += (output_shape.y+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*16]=0;\n"
" }\n"
" }\n"
" \n"
" const int out_offset=(((b_idx*out_channel_block+c_idx)*output_shape.x+oh_idx)* dst_width+ow_idx+output_pad_left)*16;\n"
"#if OUTPUT_LEFTOVERS\n"
" if ((c_idx+1)*16 >= channel) {\n"
" for (int i=0; i<8; i++) {\n"
" if ((c_idx*16+sglid<channel) && (ow_idx+i)<output_shape.y)\n"
" output[out_offset+i*16+sglid]=result[i];\n"
" }\n"
" }\n"
" else\n"
"#endif \n"
" {\n"
" if (ow_idx+8 <= output_shape.y) {\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us8((__global ushort*)(output+out_offset),as_ushort8(CONVERT_FLOAT8(result)));\n"
"#else\n"
" intel_sub_group_block_write8((__global uint*)(output+out_offset),as_uint8(CONVERT_FLOAT8(result)));\n"
"#endif\n"
" }else{\n"
" for (int i=0; i<8 && (ow_idx+i)<output_shape.y; i++) {\n"
" output[out_offset+i*16+sglid]=result[i];\n"
" }\n"
" }\n"
" }\n"
"#if RETURN_REDICE\n"
" const uint lid_x=sglid % 4;\n"
" const uint lid_y=sglid/4;\n"
" \n"
" const int width_height=output_shape.y*output_shape.x*4;\n"
" const int redice_offset=(((b_idx*out_channel_block+c_idx*4)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n"
"#if OUTPUT_LEFTOVERS\n"
" if ((c_idx+1)*16 >= channel) {\n"
" for (int i=0; i<8; i++) {\n"
" if ((c_idx*16+lid_y*4+lid_x<channel) && (ow_idx+i)<output_shape.y)\n"
" rediceOutput[redice_offset+lid_y*width_height+i*4+lid_x]=redice[i];\n"
" }\n"
" }\n"
" else\n"
"#endif\n"
" {\n"
" for (int i=0; i<8 && (ow_idx+i)<output_shape.y; i++) {\n"
" rediceOutput[redice_offset+lid_y*width_height+i*4+lid_x]=redice[i];\n"
" }\n"
" }\n"
"#endif\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void pooling_c16_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 pad_shape,\n"
" __global FLOAT *output,\n"
" __global FLOAT *rediceOutput,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int in_channel_block,\n"
" __private const int out_channel_block,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right) {\n"
" \n"
" const int ow_idx=get_global_id(1) << 3;\n"
" const int b_oh_idx=get_global_id(2);\n"
" const int c_idx=get_group_id(0);\n"
" const int sglid=get_sub_group_local_id();\n"
" \n"
" const int b_idx=b_oh_idx/output_shape.x;\n"
" const int oh_idx=b_oh_idx % output_shape.x;\n"
" const int iw_start=mad24(ow_idx,STRIDE_X,-pad_shape.y);\n"
" const int ih_start=mad24(oh_idx,STRIDE_Y,-pad_shape.x);\n"
" const int src_width=input_shape.y+input_pad_left+input_pad_right;\n"
"#ifdef POOL_AVG\n"
" COMPUTE_FLOAT8 result=(COMPUTE_FLOAT8)(0);\n"
" COMPUTE_FLOAT8 w_start=(COMPUTE_FLOAT8)(iw_start,iw_start+STRIDE_X,iw_start+STRIDE_X*2,iw_start+STRIDE_X*3,iw_start+STRIDE_X*4,iw_start+STRIDE_X*5,iw_start+STRIDE_X*6,iw_start+STRIDE_X*7);\n"
"#ifdef COUNT_INCLUDE_PADDING\n"
" COMPUTE_FLOAT8 w_size=fmin(w_start+KERNEL_X,(COMPUTE_FLOAT8)(input_shape.y+pad_shape.y))-w_start;\n"
" COMPUTE_FLOAT8 total_count=(COMPUTE_FLOAT8)(min(ih_start+KERNEL_Y,input_shape.x+pad_shape.x)-ih_start)*w_size;\n"
"#else\n"
" w_start=fmax(w_start,(COMPUTE_FLOAT8)0);\n"
" COMPUTE_FLOAT8 w_end=fmin(w_start+KERNEL_X,(COMPUTE_FLOAT8)input_shape.y);\n"
" float h_start=fmax((float)ih_start,0.0f);\n"
" float h_end=fmin(h_start+KERNEL_Y,(float)input_shape.x);\n"
" COMPUTE_FLOAT8 total_count=(w_end-w_start)*(COMPUTE_FLOAT8)(h_end-h_start);\n"
"#endif\n"
"#else\n"
" COMPUTE_FLOAT8 result=(COMPUTE_FLOAT8)(-FLT_MAX);\n"
"#if RETURN_REDICE\n"
" int8 redice=(int8)0;\n"
"#endif\n"
"#endif\n"
" const int inp_offset=mul24(mad24(mad24(mad24(b_idx,in_channel_block,c_idx),input_shape.x,ih_start),src_width,iw_start+input_pad_left),16);\n"
" for(int kh=0; kh<KERNEL_Y; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" for (int i=0; i<INPUT_LINE_SIZE; i++) {\n"
" if ((iw_start+i) >= 0 && (iw_start+i)<input_shape.y){\n"
"#ifdef MNN_SUPPORT_FP16\n"
" line_cache[i]=as_half(intel_sub_group_block_read_us((__global ushort*)(input+inp_offset+mul24(mad24(kh,src_width,i),16))));\n"
"#else\n"
" line_cache[i]=as_float(intel_sub_group_block_read((__global uint*)(input+inp_offset+mul24(mad24(kh,src_width,i),16))));\n"
"#endif\n"
" } else{\n"
"#ifdef POOL_AVG\n"
" line_cache[i]=0;\n"
"#else\n"
" line_cache[i]=(FLOAT)(-FLT_MAX);\n"
"#endif\n"
" }\n"
" }\n"
" for(int kw=0; kw<KERNEL_X; kw++) {\n"
" COMPUTE_FLOAT8 src;\n"
" __attribute__((opencl_unroll_hint(8)))\n"
" for (int i=0; i<8; i++) {\n"
" src[i]=line_cache[kw+STRIDE_X*i];\n"
" }\n"
"#ifdef POOL_AVG\n"
" result += src;\n"
"#else\n"
"#if RETURN_REDICE\n"
" redice=convert_int8(src>result) ? ((int8)((ih_start+kh)*input_shape.y+iw_start+kw)+(int8)(0,1,2,3,4,5,6,7)*STRIDE_X) : redice;\n"
"#endif\n"
" result=fmax(result,src);\n"
"#endif\n"
" }\n"
" }\n"
"#ifdef POOL_AVG\n"
" result=result/total_count;\n"
"#endif\n"
" const uint lid_x=sglid % 4;\n"
" const uint lid_y=sglid/4;\n"
" \n"
" const int out_offset=(((b_idx+c_idx*4*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx+output_pad_left)*4;\n"
" const int batch_width_height=batch*output_shape.y*output_shape.x*4;\n"
"#if RETURN_REDICE\n"
" const int redice_offset=(((b_idx+c_idx*4*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n"
"#endif\n"
"#if OUTPUT_LEFTOVERS\n"
" if ((c_idx+1)*16 >= channel) {\n"
" for (int i=0; i<8; i++) {\n"
" if ((c_idx*16+lid_y*4+lid_x<channel) && (ow_idx+i)<output_shape.y) {\n"
" output[out_offset+lid_y*batch_width_height+i*4+lid_x]=result[i];\n"
"#if RETURN_REDICE\n"
" rediceOutput[redice_offset+lid_y*batch_width_height+i*4+lid_x]=redice[i];\n"
"#endif\n"
" }\n"
" }\n"
" }\n"
" else\n"
"#endif \n"
" {\n"
" for (int i=0; i<8 && (ow_idx+i)<output_shape.y; i++) {\n"
" output[out_offset+lid_y*batch_width_height+i*4+lid_x]=result[i];\n"
"#if RETURN_REDICE\n"
" rediceOutput[redice_offset+lid_y*batch_width_height+i*4+lid_x]=redice[i];\n"
"#endif\n"
" }\n"
" }\n"
"}\n"
;
#endif
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for the buffer-backed pooling kernels (compiled at runtime).
// Tensor layout, as implied by the offset math in the kernels: element
// (b, c4, h, w) lives at (((b + c4*batch)*H + h)*W + w)*4, i.e. channels are
// packed in groups of 4. All int2 shape arguments are (height, width) pairs.
// Kernels:
//   pooling            - one work-item per output pixel, global id
//                        (ow, b*OH + oh, c4); max pooling, or average pooling
//                        when POOL_AVG is defined.
//   global_pooling_buf - only when LOCAL_SIZE is defined; each work-item strides
//                        over the whole H*W plane of one (c4, b) pair and the
//                        partial results are tree-reduced in local memory
//                        (presumably launched with dim0 == LOCAL_SIZE; the
//                        halving reduction assumes LOCAL_SIZE is a power of 2).
// With RETURN_REDICE, rediceOutput receives the flat index h*W + w of the max.
// FLOAT, COMPUTE_FLOAT4, CONVERT_COMPUTE_FLOAT4 and CONVERT_FLOAT4 are not
// defined in this string; presumably supplied as build options - TODO confirm.
// NOTE(review): the RETURN_REDICE stores sit outside the POOL_AVG guards, so
// POOL_AVG and RETURN_REDICE must not both be enabled (redice is undeclared).
// NOTE(review): if this file is generated from .cl sources, mirror fixes there.
const char* pooling_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 pad_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 kernel_shape,\n"
" __global FLOAT *output,\n"
" __global FLOAT *rediceOutput,\n"
" __private const int batch) {\n"
" \n"
" const int ow_idx=get_global_id(0);\n"
" const int b_oh_idx=get_global_id(1);\n"
" const int c_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(ow_idx,b_oh_idx,c_idx);\n"
" \n"
" const int b_idx=b_oh_idx/output_shape.x;\n"
" const int oh_idx=b_oh_idx % output_shape.x;\n"
// Window origin; negative when the window starts inside the padding.
" const int iw_start=mad24(ow_idx,stride_shape.y,-pad_shape.y);\n"
" const int ih_start=mad24(oh_idx,stride_shape.x,-pad_shape.x);\n"
" \n"
" #ifdef POOL_AVG\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(0);\n"
// inp_offset may address before the plane start; only in-range taps are read.
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;\n"
" #ifdef COUNT_INCLUDE_PADDING\n"
" int total_count=(min(ih_start+kernel_shape.x,input_shape.x+pad_shape.x)-ih_start)*(min(iw_start+kernel_shape.y,input_shape.y+pad_shape.y)-iw_start);\n"
" #else\n"
" int total_count=0;\n"
" #endif\n"
" for(int kh=0; kh<kernel_shape.x; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" for(int kw=0; kw<kernel_shape.y; kw++) {\n"
" int iw_cur=iw_start+kw;\n"
" if(iw_cur<0 || iw_cur >= input_shape.y) {\n"
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
" result += inp_data;\n"
" #ifndef COUNT_INCLUDE_PADDING\n"
" total_count++;\n"
" #endif\n"
" }\n"
" }\n"
// Single-precision cast instead of a double literal (1.0*), so the kernel
// never computes in fp64 on devices that expose it.
" result=result/(COMPUTE_FLOAT4)((float)total_count);\n"
" #else\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(-FLT_MAX);\n"
" #if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;\n"
" for(int kh=0; kh<kernel_shape.x; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
" if(ih_cur<0 || ih_cur >= input_shape.x) {\n"
" continue;\n"
" }\n"
" for(int kw=0; kw<kernel_shape.y; kw++) {\n"
" int iw_cur=iw_start+kw;\n"
" if(iw_cur<0 || iw_cur >= input_shape.y) {\n"
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
// Argmax is updated before the running max so the strict '>' can fire.
" #if RETURN_REDICE\n"
" redice=inp_data>result ? (int4)((ih_start+kh)*input_shape.y+iw_start+kw) : redice;\n"
" #endif\n"
" result=fmax(result,inp_data);\n"
" }\n"
" }\n"
" #endif\n"
" \n"
" const int out_offset=(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n"
" #if RETURN_REDICE\n"
" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n"
" #endif\n"
"}\n"
"#ifdef LOCAL_SIZE\n"
"__kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 pad_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 kernel_shape,\n"
" __global FLOAT *output,\n"
" __global FLOAT *rediceOutput,\n"
" __private const int batch) {\n"
" const int local_id=get_local_id(0);\n"
" const int output_channel_idx=get_global_id(1);\n"
" const int output_batch_idx=get_global_id(2);\n"
"#ifdef POOL_AVG\n"
" COMPUTE_FLOAT4 output_result=0;\n"
"#else\n"
" COMPUTE_FLOAT4 output_result=(COMPUTE_FLOAT4)(-FLT_MAX);\n"
"#if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" int4 local rediceId[LOCAL_SIZE];\n"
"#endif\n"
"#endif\n"
" COMPUTE_FLOAT4 local sum[LOCAL_SIZE];\n"
" const int inp_offset=((output_batch_idx+output_channel_idx*batch)*input_shape.x)*input_shape.y*4;\n"
" const int size=input_shape.x*input_shape.y;\n"
// Each work-item folds plane elements local_id, local_id+LOCAL_SIZE, ...
" for(int i=local_id; i<size; i+=LOCAL_SIZE){\n"
" int w=i % input_shape.y;;\n"
" int h=i/input_shape.y;\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(h*input_shape.y+w)*4));\n"
"#ifdef POOL_AVG\n"
" output_result += in;\n"
"#else\n"
// Record the argmax BEFORE updating the running max: once fmax has run,
// output_result >= in and the strict comparison could never select i.
"#if RETURN_REDICE\n"
" redice=in>output_result ? (int4)(i) : redice;\n"
"#endif\n"
" output_result=fmax(output_result,in);\n"
"#endif\n"
" }\n"
" \n"
" sum[local_id]=output_result;\n"
"#if RETURN_REDICE\n"
" rediceId[local_id]=redice;\n"
"#endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
// Tree reduction in local memory; the barrier is reached by every work-item.
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (local_id<i)\n"
"#ifdef POOL_AVG\n"
" sum[local_id]=sum[local_id]+sum[local_id+i];\n"
"#else\n"
" {\n"
" sum[local_id]=fmax(sum[local_id],sum[local_id+i]);\n"
// Evaluated after fmax, yet still correct: sum[local_id] stays strictly
// greater only if it was originally larger; otherwise (including ties) the
// partner's index is taken.
"#if RETURN_REDICE\n"
" rediceId[local_id]=sum[local_id]>sum[local_id+i] ? rediceId[local_id] : rediceId[local_id+i];\n"
"#endif\n"
" }\n"
"#endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" output_result=sum[0];\n"
"#ifdef POOL_AVG\n"
" output_result /= (input_shape.x*input_shape.y);\n"
"#endif\n"
" const int out_offset=(output_batch_idx+output_channel_idx*batch)*4;\n"
" vstore4(CONVERT_FLOAT4(output_result),0,output+out_offset);\n"
"#if RETURN_REDICE\n"
" redice=rediceId[0];\n"
" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n"
"#endif\n"
"}\n"
"#endif\n"
;
#endif
// OpenCL C source (image2d-backed) for a Winograd input ("source") transform.
// Each work-item, for pos = (tile index, channel block), reads a 6x6 tile of
// texels starting at (unitWidth_idx*2 - padX, unitHeight_idx*2 - padY),
// applies a separable 6-point linear transform first along y and then along x,
// and writes the 36 results to rows unitHeight_idx + unitHeight*n (n = 0..35)
// at column dstX = pos.y*unitWidth + unitWidth_idx.
// Presumably the F(2,5) input transform (output tile 2, kernel 5, 2+5-1 = 6),
// judging by the name and tile size - TODO confirm.
// The transform matrix BT used in both passes, read off the code below:
//   row0: [ 1,  0,         -1.25,      0,          0.25,      0 ]
//   row1: [ 0,  0.666667,   0.666667, -0.166667,  -0.166667,  0 ]
//   row2: [ 0, -0.666667,   0.666667,  0.166667,  -0.166667,  0 ]
//   row3: [ 0, -0.0833333, -0.0416667, 0.0833333,  0.0416667, 0 ]
//   row4: [ 0,  0.0833333, -0.0416667,-0.0833333,  0.0416667, 0 ]
//   row5: [ 0,  4,          0,        -5,          0,         1 ]
// Input addressing: x = sx + pos.y*srcWidth, y = batchOffset*srcHeight + sy.
// Out-of-range taps use coordinate -1, so the CLK_ADDRESS_CLAMP sampler returns
// the border color (the zero padding, presumably, for the RGBA images used).
// RI_F, WI_F and FLOAT are not defined in this string; presumably supplied as
// build options - TODO confirm.
const char* winogradTransformSource2_5_1 = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void winogradTransformSource(__read_only image2d_t uInput,// 0\n"
" __write_only image2d_t uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,\n"
" __private const int batchOffset) {\n"
// pos.x: tile index in [0, unitWidth*unitHeight); pos.y: channel block in
// [0, srcChannelC4). Work-items outside that range do nothing.
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" if (pos.x<unitWidth*unitHeight && pos.y<srcChannelC4) {\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int dstX=mad24(pos.y,unitWidth,unitWidth_idx);\n"
" {\n"
// Tiles advance by 2 in each direction (the output tile size).
" int sxStart=(unitWidth_idx)*2-padX;\n"
" int syStart=(unitHeight_idx)*2-padY;\n"
" FLOAT4 S00;\n"
" FLOAT4 S10;\n"
" FLOAT4 S20;\n"
" FLOAT4 S30;\n"
" FLOAT4 S40;\n"
" FLOAT4 S50;\n"
" FLOAT4 S01;\n"
" FLOAT4 S11;\n"
" FLOAT4 S21;\n"
" FLOAT4 S31;\n"
" FLOAT4 S41;\n"
" FLOAT4 S51;\n"
" FLOAT4 S02;\n"
" FLOAT4 S12;\n"
" FLOAT4 S22;\n"
" FLOAT4 S32;\n"
" FLOAT4 S42;\n"
" FLOAT4 S52;\n"
" FLOAT4 S03;\n"
" FLOAT4 S13;\n"
" FLOAT4 S23;\n"
" FLOAT4 S33;\n"
" FLOAT4 S43;\n"
" FLOAT4 S53;\n"
" FLOAT4 S04;\n"
" FLOAT4 S14;\n"
" FLOAT4 S24;\n"
" FLOAT4 S34;\n"
" FLOAT4 S44;\n"
" FLOAT4 S54;\n"
" FLOAT4 S05;\n"
" FLOAT4 S15;\n"
" FLOAT4 S25;\n"
" FLOAT4 S35;\n"
" FLOAT4 S45;\n"
" FLOAT4 S55;\n"
// Load the 6x6 tile: S<i><j> is the texel at (sxStart+i, syStart+j); taps
// outside the source plane get coordinate -1 (border color via the sampler).
" {\n"
" int sx=0+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S00=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S10=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S20=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S30=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=4+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S40=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=5+sxStart;\n"
" int sy=0+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S50=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S01=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S11=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S21=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S31=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=4+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S41=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=5+sxStart;\n"
" int sy=1+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S51=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S02=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S12=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S22=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S32=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=4+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S42=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=5+sxStart;\n"
" int sy=2+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S52=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S03=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S13=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S23=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S33=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=4+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S43=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=5+sxStart;\n"
" int sy=3+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S53=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=4+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S04=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=4+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S14=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=4+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S24=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=4+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S34=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=4+sxStart;\n"
" int sy=4+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S44=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=5+sxStart;\n"
" int sy=4+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S54=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=5+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S05=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=5+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S15=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=5+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S25=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=5+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S35=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=4+sxStart;\n"
" int sy=5+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S45=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
" {\n"
" int sx=5+sxStart;\n"
" int sy=5+syStart;\n"
" int imageSx=select(sx+pos.y*srcWidth,-1,sx<0 || sx >= srcWidth);\n"
" int imageSy=select(batchOffset*srcHeight+sy,-1,sy<0 || sy >= srcHeight);\n"
" S55=RI_F(uInput,SAMPLER,(int2)(imageSx,imageSy));\n"
" }\n"
// Pass 1 (along y): m<i><k> = sum_j BT[k][j] * S<i><j> for each column i.
" FLOAT4 m00=+S00-(FLOAT)1.25*S02+(FLOAT)0.25*S04;\n"
" FLOAT4 m10=+S10-(FLOAT)1.25*S12+(FLOAT)0.25*S14;\n"
" FLOAT4 m20=+S20-(FLOAT)1.25*S22+(FLOAT)0.25*S24;\n"
" FLOAT4 m30=+S30-(FLOAT)1.25*S32+(FLOAT)0.25*S34;\n"
" FLOAT4 m40=+S40-(FLOAT)1.25*S42+(FLOAT)0.25*S44;\n"
" FLOAT4 m50=+S50-(FLOAT)1.25*S52+(FLOAT)0.25*S54;\n"
" FLOAT4 m01=+(FLOAT)0.666667*S01+(FLOAT)0.666667*S02-(FLOAT)0.166667*S03-(FLOAT)0.166667*S04;\n"
" FLOAT4 m11=+(FLOAT)0.666667*S11+(FLOAT)0.666667*S12-(FLOAT)0.166667*S13-(FLOAT)0.166667*S14;\n"
" FLOAT4 m21=+(FLOAT)0.666667*S21+(FLOAT)0.666667*S22-(FLOAT)0.166667*S23-(FLOAT)0.166667*S24;\n"
" FLOAT4 m31=+(FLOAT)0.666667*S31+(FLOAT)0.666667*S32-(FLOAT)0.166667*S33-(FLOAT)0.166667*S34;\n"
" FLOAT4 m41=+(FLOAT)0.666667*S41+(FLOAT)0.666667*S42-(FLOAT)0.166667*S43-(FLOAT)0.166667*S44;\n"
" FLOAT4 m51=+(FLOAT)0.666667*S51+(FLOAT)0.666667*S52-(FLOAT)0.166667*S53-(FLOAT)0.166667*S54;\n"
" FLOAT4 m02=-(FLOAT)0.666667*S01+(FLOAT)0.666667*S02+(FLOAT)0.166667*S03-(FLOAT)0.166667*S04;\n"
" FLOAT4 m12=-(FLOAT)0.666667*S11+(FLOAT)0.666667*S12+(FLOAT)0.166667*S13-(FLOAT)0.166667*S14;\n"
" FLOAT4 m22=-(FLOAT)0.666667*S21+(FLOAT)0.666667*S22+(FLOAT)0.166667*S23-(FLOAT)0.166667*S24;\n"
" FLOAT4 m32=-(FLOAT)0.666667*S31+(FLOAT)0.666667*S32+(FLOAT)0.166667*S33-(FLOAT)0.166667*S34;\n"
" FLOAT4 m42=-(FLOAT)0.666667*S41+(FLOAT)0.666667*S42+(FLOAT)0.166667*S43-(FLOAT)0.166667*S44;\n"
" FLOAT4 m52=-(FLOAT)0.666667*S51+(FLOAT)0.666667*S52+(FLOAT)0.166667*S53-(FLOAT)0.166667*S54;\n"
" FLOAT4 m03 =\n"
" -(FLOAT)0.0833333*S01-(FLOAT)0.0416667*S02+(FLOAT)0.0833333*S03+(FLOAT)0.0416667*S04;\n"
" FLOAT4 m13 =\n"
" -(FLOAT)0.0833333*S11-(FLOAT)0.0416667*S12+(FLOAT)0.0833333*S13+(FLOAT)0.0416667*S14;\n"
" FLOAT4 m23 =\n"
" -(FLOAT)0.0833333*S21-(FLOAT)0.0416667*S22+(FLOAT)0.0833333*S23+(FLOAT)0.0416667*S24;\n"
" FLOAT4 m33 =\n"
" -(FLOAT)0.0833333*S31-(FLOAT)0.0416667*S32+(FLOAT)0.0833333*S33+(FLOAT)0.0416667*S34;\n"
" FLOAT4 m43 =\n"
" -(FLOAT)0.0833333*S41-(FLOAT)0.0416667*S42+(FLOAT)0.0833333*S43+(FLOAT)0.0416667*S44;\n"
" FLOAT4 m53 =\n"
" -(FLOAT)0.0833333*S51-(FLOAT)0.0416667*S52+(FLOAT)0.0833333*S53+(FLOAT)0.0416667*S54;\n"
" FLOAT4 m04 =\n"
" +(FLOAT)0.0833333*S01-(FLOAT)0.0416667*S02-(FLOAT)0.0833333*S03+(FLOAT)0.0416667*S04;\n"
" FLOAT4 m14 =\n"
" +(FLOAT)0.0833333*S11-(FLOAT)0.0416667*S12-(FLOAT)0.0833333*S13+(FLOAT)0.0416667*S14;\n"
" FLOAT4 m24 =\n"
" +(FLOAT)0.0833333*S21-(FLOAT)0.0416667*S22-(FLOAT)0.0833333*S23+(FLOAT)0.0416667*S24;\n"
" FLOAT4 m34 =\n"
" +(FLOAT)0.0833333*S31-(FLOAT)0.0416667*S32-(FLOAT)0.0833333*S33+(FLOAT)0.0416667*S34;\n"
" FLOAT4 m44 =\n"
" +(FLOAT)0.0833333*S41-(FLOAT)0.0416667*S42-(FLOAT)0.0833333*S43+(FLOAT)0.0416667*S44;\n"
" FLOAT4 m54 =\n"
" +(FLOAT)0.0833333*S51-(FLOAT)0.0416667*S52-(FLOAT)0.0833333*S53+(FLOAT)0.0416667*S54;\n"
" FLOAT4 m05=+(FLOAT)4.0*S01-(FLOAT)5.0*S03+S05;\n"
" FLOAT4 m15=+(FLOAT)4.0*S11-(FLOAT)5.0*S13+S15;\n"
" FLOAT4 m25=+(FLOAT)4.0*S21-(FLOAT)5.0*S23+S25;\n"
" FLOAT4 m35=+(FLOAT)4.0*S31-(FLOAT)5.0*S33+S35;\n"
" FLOAT4 m45=+(FLOAT)4.0*S41-(FLOAT)5.0*S43+S45;\n"
" FLOAT4 m55=+(FLOAT)4.0*S51-(FLOAT)5.0*S53+S55;\n"
// Pass 2 (along x): output row unitHeight_idx + unitHeight*(6*k + r) receives
// sum_i BT[r][i] * m<i><k>.
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*0),+m00-(FLOAT)1.25*m20+(FLOAT)0.25*m40);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*1),\n"
" +(FLOAT)0.666667*m10+(FLOAT)0.666667*m20-(FLOAT)0.166667*m30-(FLOAT)0.166667*m40);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*2),\n"
" -(FLOAT)0.666667*m10+(FLOAT)0.666667*m20+(FLOAT)0.166667*m30-(FLOAT)0.166667*m40);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*3),\n"
" -(FLOAT)0.0833333*m10-(FLOAT)0.0416667*m20+(FLOAT)0.0833333*m30+(FLOAT)0.0416667*m40);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*4),\n"
" +(FLOAT)0.0833333*m10-(FLOAT)0.0416667*m20-(FLOAT)0.0833333*m30+(FLOAT)0.0416667*m40);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*5),+(FLOAT)4.0*m10-(FLOAT)5.0*m30+m50);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*6),+m01-(FLOAT)1.25*m21+(FLOAT)0.25*m41);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*7),\n"
" +(FLOAT)0.666667*m11+(FLOAT)0.666667*m21-(FLOAT)0.166667*m31-(FLOAT)0.166667*m41);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*8),\n"
" -(FLOAT)0.666667*m11+(FLOAT)0.666667*m21+(FLOAT)0.166667*m31-(FLOAT)0.166667*m41);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*9),\n"
" -(FLOAT)0.0833333*m11-(FLOAT)0.0416667*m21+(FLOAT)0.0833333*m31+(FLOAT)0.0416667*m41);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*10),\n"
" +(FLOAT)0.0833333*m11-(FLOAT)0.0416667*m21-(FLOAT)0.0833333*m31+(FLOAT)0.0416667*m41);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*11),+(FLOAT)4.0*m11-(FLOAT)5.0*m31+m51);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*12),+m02-(FLOAT)1.25*m22+(FLOAT)0.25*m42);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*13),\n"
" +(FLOAT)0.666667*m12+(FLOAT)0.666667*m22-(FLOAT)0.166667*m32-(FLOAT)0.166667*m42);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*14),\n"
" -(FLOAT)0.666667*m12+(FLOAT)0.666667*m22+(FLOAT)0.166667*m32-(FLOAT)0.166667*m42);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*15),\n"
" -(FLOAT)0.0833333*m12-(FLOAT)0.0416667*m22+(FLOAT)0.0833333*m32+(FLOAT)0.0416667*m42);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*16),\n"
" +(FLOAT)0.0833333*m12-(FLOAT)0.0416667*m22-(FLOAT)0.0833333*m32+(FLOAT)0.0416667*m42);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*17),+(FLOAT)4.0*m12-(FLOAT)5.0*m32+m52);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*18),+m03-(FLOAT)1.25*m23+(FLOAT)0.25*m43);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*19),\n"
" +(FLOAT)0.666667*m13+(FLOAT)0.666667*m23-(FLOAT)0.166667*m33-(FLOAT)0.166667*m43);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*20),\n"
" -(FLOAT)0.666667*m13+(FLOAT)0.666667*m23+(FLOAT)0.166667*m33-(FLOAT)0.166667*m43);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*21),\n"
" -(FLOAT)0.0833333*m13-(FLOAT)0.0416667*m23+(FLOAT)0.0833333*m33+(FLOAT)0.0416667*m43);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*22),\n"
" +(FLOAT)0.0833333*m13-(FLOAT)0.0416667*m23-(FLOAT)0.0833333*m33+(FLOAT)0.0416667*m43);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*23),+(FLOAT)4.0*m13-(FLOAT)5.0*m33+m53);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*24),+m04-(FLOAT)1.25*m24+(FLOAT)0.25*m44);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*25),\n"
" +(FLOAT)0.666667*m14+(FLOAT)0.666667*m24-(FLOAT)0.166667*m34-(FLOAT)0.166667*m44);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*26),\n"
" -(FLOAT)0.666667*m14+(FLOAT)0.666667*m24+(FLOAT)0.166667*m34-(FLOAT)0.166667*m44);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*27),\n"
" -(FLOAT)0.0833333*m14-(FLOAT)0.0416667*m24+(FLOAT)0.0833333*m34+(FLOAT)0.0416667*m44);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*28),\n"
" +(FLOAT)0.0833333*m14-(FLOAT)0.0416667*m24-(FLOAT)0.0833333*m34+(FLOAT)0.0416667*m44);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*29),+(FLOAT)4.0*m14-(FLOAT)5.0*m34+m54);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*30),+m05-(FLOAT)1.25*m25+(FLOAT)0.25*m45);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*31),\n"
" +(FLOAT)0.666667*m15+(FLOAT)0.666667*m25-(FLOAT)0.166667*m35-(FLOAT)0.166667*m45);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*32),\n"
" -(FLOAT)0.666667*m15+(FLOAT)0.666667*m25+(FLOAT)0.166667*m35-(FLOAT)0.166667*m45);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*33),\n"
" -(FLOAT)0.0833333*m15-(FLOAT)0.0416667*m25+(FLOAT)0.0833333*m35+(FLOAT)0.0416667*m45);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*34),\n"
" +(FLOAT)0.0833333*m15-(FLOAT)0.0416667*m25-(FLOAT)0.0833333*m35+(FLOAT)0.0416667*m45);\n"
" WI_F(uOutput,(int2)(dstX,unitHeight_idx+unitHeight*35),+(FLOAT)4.0*m15-(FLOAT)5.0*m35+m55);\n"
" }\n"
" }\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for the buffer-backed element-wise unary kernel.
// Each work-item handles 4 consecutive elements starting at
// offset = get_global_id(0) * 4; get_global_id(1) is only bounds-checked.
// `size` is compared against offset, i.e. it is the total element count.
// OPERATOR, INPUT_TYPE, OUTPUT_TYPE and CONVERT_OUTPUT4 are not defined in this
// string; presumably supplied as build options (OPERATOR being an expression
// over the float4 `in`, e.g. gelu(in)) - TODO confirm.
// With PACK_LEAVE, a trailing group of fewer than 4 elements is loaded and
// stored element-wise so nothing past `size` is touched.
const char* unary_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
// gelu: tanh-form GELU, 0.5*x*(1 + tanh(0.79788458*(x + 0.044715*x^3))) with
// 0.79788458 ~= sqrt(2/pi). tanh is replaced by a rational approximation and
// saturated to +/-1 outside (-5, 5].
"inline float4 gelu(float4 in){\n"
" float4 value=0.79788458f*(0.044715f*in*in*in+in);\n"
" float4 x2=value*value;\n"
" float4 dst=value>(float4)5.0f ? (float4)1.0f : (value <= -(float4)5.0f ? -(float4)1.0f :\n"
" (value*(135135.0f+x2*(17325.0f+x2*(378.0f+x2))))/(135135.0f+x2*(62370.0f+x2*(3150.0f+x2*28.0f))));\n"
" return (1.0f+dst)*in*0.5f;\n"
"}\n"
"__kernel void unary_buf(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int size) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int offset=x << 2;\n"
"#ifdef PACK_LEAVE\n"
" if(offset+3 >= size){\n"
" int remain=size-offset;\n"
// Zero-fill so OPERATOR never reads indeterminate lanes beyond `remain`
// (those lanes are computed but never stored).
" float4 in=(float4)(0.0f);\n"
" float* in_ptr=(float*)&in;\n"
" for(int i=0; i<remain; ++i){\n"
" in_ptr[i]=(float)input[offset+i];\n"
" }\n"
" float4 out=OPERATOR;\n"
" float* out_ptr=(float*)&out;\n"
" for(int i=0; i<remain; ++i){\n"
" output[offset+i]=(OUTPUT_TYPE)out_ptr[i];\n"
" }\n"
" }else {\n"
"#endif\n"
// Full group of 4: vector load, apply OPERATOR, vector store.
" float4 in=convert_float4(vload4(0,input+offset));\n"
" float4 out=OPERATOR;\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset);\n"
"#ifdef PACK_LEAVE\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* depthwise_conv2d_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define DW_CONV_NEXT_LINE_CAL(x,y) "" x = mad(inValue0, weights0, x); "" x = mad(inValue1, weights1, x); "" x = mad(inValue2, weights2, x); "" y = mad(inValue1, weights0, y); "" y = mad(inValue2, weights1, y); "" y=mad(inValue3,weights2,y);\n"
"__kernel\n"
"void depthwise_conv2d_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/4\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" COMPUTE_FLOAT4 outValue2=outValue0;\n"
" COMPUTE_FLOAT4 outValue3=outValue0;\n"
" const int out_w4_idx=out_w_idx << 2;\n"
" const int in_w_start_0=out_w4_idx*stride_hw.y-pad_hw.y;\n"
" const int in_w_start_1=in_w_start_0+stride_hw.y;\n"
" const int in_w_start_2=in_w_start_1+stride_hw.y;\n"
" const int in_w_start_3=in_w_start_2+stride_hw.y;\n"
" const int in_h_start=out_h_idx*stride_hw.x-pad_hw.x;\n"
" \n"
" for (int kh=0; kh<filter_hw.x; kh++) {\n"
" const int in_h_cur=in_h_start+kh*dilate_hw.x;\n"
" if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n"
" \n"
" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" for (int kw=0; kw<filter_hw.y; kw++) {\n"
" const int filter_idx=mad24(kh,filter_hw.y,kw);\n"
" const int kw_dilate=kw*dilate_hw.y;\n"
" COMPUTE_FLOAT4 inValue0=(in_w_start_0+kw_dilate<0 || in_w_start_0+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+0,input+inp_offset));\n"
" COMPUTE_FLOAT4 inValue1=(in_w_start_1+kw_dilate<0 || in_w_start_1+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+1*stride_hw.y,input+inp_offset));\n"
" COMPUTE_FLOAT4 inValue2=(in_w_start_2+kw_dilate<0 || in_w_start_2+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+2*stride_hw.y,input+inp_offset));\n"
" COMPUTE_FLOAT4 inValue3=(in_w_start_3+kw_dilate<0 || in_w_start_3+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+3*stride_hw.y,input+inp_offset));\n"
" //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n"
" //index: [0,filterIdx,0,inChannelBlockIdx]\n"
" COMPUTE_FLOAT4 weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" outValue2=mad(inValue2,weights,outValue2);\n"
" outValue3=mad(inValue3,weights,outValue3);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
" outValue2=fmax(outValue2,(COMPUTE_FLOAT4)0);\n"
" outValue3=fmax(outValue3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue2=clamp(outValue2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue3=clamp(outValue3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n"
" const int remain=out_hw.y-out_w4_idx;\n"
" if (remain >= 4) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),2,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue3),3,output+out_offset);\n"
" } else if (remain == 3) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),2,output+out_offset);\n"
" } else if (remain == 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
" \n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/4\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" const int out_w2_idx=out_w_idx << 1;\n"
" const int in_w_start_0=out_w2_idx*stride_hw.y-pad_hw.y;\n"
" const int in_w_start_1=in_w_start_0+stride_hw.y;\n"
" \n"
" const int in_h_start=out_h_idx*stride_hw.x-pad_hw.x;\n"
" \n"
" for (int kh=0; kh<filter_hw.x; kh++) {\n"
" const int in_h_cur=in_h_start+kh*dilate_hw.x;\n"
" if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n"
" \n"
" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" for (int kw=0; kw<filter_hw.y; kw++) {\n"
" const int filter_idx=mad24(kh,filter_hw.y,kw);\n"
" const int kw_dilate=kw*dilate_hw.y;\n"
" COMPUTE_FLOAT4 inValue0=(in_w_start_0+kw_dilate<0 || in_w_start_0+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+0,input+inp_offset));\n"
" COMPUTE_FLOAT4 inValue1=(in_w_start_1+kw_dilate<0 || in_w_start_1+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+1*stride_hw.y,input+inp_offset));\n"
" //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n"
" //index: [0,filterIdx,0,inChannelBlockIdx]\n"
" COMPUTE_FLOAT4 weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n"
" const int remain=out_hw.y-out_w2_idx;\n"
" if (remain >= 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
" \n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_c4h1w1(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/4\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" const int in_w_start_0=out_w_idx*stride_hw.y-pad_hw.y;\n"
" const int in_h_start=out_h_idx*stride_hw.x-pad_hw.x;\n"
" \n"
" for (int kh=0; kh<filter_hw.x; kh++) {\n"
" const int in_h_cur=in_h_start+kh*dilate_hw.x;\n"
" if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n"
" \n"
" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" for (int kw=0; kw<filter_hw.y; kw++) {\n"
" const int filter_idx=mad24(kh,filter_hw.y,kw);\n"
" const int kw_dilate=kw*dilate_hw.y;\n"
" COMPUTE_FLOAT4 inValue0=(in_w_start_0+kw_dilate<0 || in_w_start_0+kw_dilate >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw_dilate+0,input+inp_offset));\n"
" //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n"
" //index: [0,filterIdx,0,inChannelBlockIdx]\n"
" COMPUTE_FLOAT4 weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/4\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx+0,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" COMPUTE_FLOAT4 outValue2=outValue0;\n"
" COMPUTE_FLOAT4 outValue3=outValue0;\n"
" COMPUTE_FLOAT4 outValue4=CONVERT_COMPUTE_FLOAT4(vload4(c_idx+1,bias));\n"
" COMPUTE_FLOAT4 outValue5=outValue4;\n"
" COMPUTE_FLOAT4 outValue6=outValue4;\n"
" COMPUTE_FLOAT4 outValue7=outValue4;\n"
" const int out_w4_idx=out_w_idx << 2;\n"
" const int in_w_start_0=out_w4_idx-pad_hw.y;\n"
" const int in_w_start_1=in_w_start_0+1;\n"
" const int in_w_start_2=in_w_start_0+2;\n"
" const int in_w_start_3=in_w_start_0+3;\n"
" const int in_h_start=out_h_idx-pad_hw.x;\n"
" \n"
" for (int kh=0; kh<filter_hw.x; kh++) {\n"
" const int in_h_cur=in_h_start+kh;\n"
" if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n"
" \n"
" int inp_offset_c0=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" int inp_offset_c1=(((b_idx+(c_idx+1)*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" for (int kw=0; kw<filter_hw.y; kw++) {\n"
" const int filter_idx=mad24(kh,filter_hw.y,kw);\n"
" COMPUTE_FLOAT4 inValue0=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c0));\n"
" COMPUTE_FLOAT4 inValue1=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c0));\n"
" COMPUTE_FLOAT4 inValue2=(in_w_start_2+kw<0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2,input+inp_offset_c0));\n"
" COMPUTE_FLOAT4 inValue3=(in_w_start_3+kw<0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3,input+inp_offset_c0));\n"
" COMPUTE_FLOAT4 inValue4=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c1));\n"
" COMPUTE_FLOAT4 inValue5=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c1));\n"
" COMPUTE_FLOAT4 inValue6=(in_w_start_2+kw<0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2,input+inp_offset_c1));\n"
" COMPUTE_FLOAT4 inValue7=(in_w_start_3+kw<0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3,input+inp_offset_c1));\n"
" \n"
" //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n"
" //index: [0,filterIdx,0,inChannelBlockIdx]\n"
" COMPUTE_FLOAT4 weights_0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+0)*4));\n"
" COMPUTE_FLOAT4 weights_1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+1)*4));\n"
" outValue0=mad(inValue0,weights_0,outValue0);\n"
" outValue1=mad(inValue1,weights_0,outValue1);\n"
" outValue2=mad(inValue2,weights_0,outValue2);\n"
" outValue3=mad(inValue3,weights_0,outValue3);\n"
" \n"
" outValue4=mad(inValue4,weights_1,outValue4);\n"
" outValue5=mad(inValue5,weights_1,outValue5);\n"
" outValue6=mad(inValue6,weights_1,outValue6);\n"
" outValue7=mad(inValue7,weights_1,outValue7);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
" outValue2=fmax(outValue2,(COMPUTE_FLOAT4)0);\n"
" outValue3=fmax(outValue3,(COMPUTE_FLOAT4)0);\n"
" \n"
" outValue4=fmax(outValue4,(COMPUTE_FLOAT4)0);\n"
" outValue5=fmax(outValue5,(COMPUTE_FLOAT4)0);\n"
" outValue6=fmax(outValue6,(COMPUTE_FLOAT4)0);\n"
" outValue7=fmax(outValue7,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue2=clamp(outValue2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue3=clamp(outValue3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" \n"
" outValue4=clamp(outValue4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue5=clamp(outValue5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue6=clamp(outValue6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue7=clamp(outValue7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n"
" const int remain=out_hw.y-out_w4_idx;\n"
" if (remain >= 4) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),2,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue3),3,output+out_offset);\n"
" } else if (remain == 3) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),2,output+out_offset);\n"
" } else if (remain == 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
" \n"
" if(c_idx+1 >= c_blocks) return;\n"
" \n"
" out_offset += batch*out_hw.x*out_hw.y*4;\n"
" if (remain >= 4) {\n"
" vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue5),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue6),2,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue7),3,output+out_offset);\n"
" } else if (remain == 3) {\n"
" vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue5),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue6),2,output+out_offset);\n"
" } else if (remain == 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue5),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n"
" }\n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/4\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx+0,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" COMPUTE_FLOAT4 outValue4=CONVERT_COMPUTE_FLOAT4(vload4(c_idx+1,bias));\n"
" COMPUTE_FLOAT4 outValue5=outValue4;\n"
" const int out_w2_idx=out_w_idx << 1;\n"
" const int in_w_start_0=out_w2_idx-pad_hw.y;\n"
" const int in_w_start_1=in_w_start_0+1;\n"
" const int in_h_start=out_h_idx-pad_hw.x;\n"
" \n"
" for (int kh=0; kh<filter_hw.x; kh++) {\n"
" const int in_h_cur=in_h_start+kh;\n"
" if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n"
" \n"
" int inp_offset_c0=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" int inp_offset_c1=(((b_idx+(c_idx+1)*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" for (int kw=0; kw<filter_hw.y; kw++) {\n"
" const int filter_idx=mad24(kh,filter_hw.y,kw);\n"
" COMPUTE_FLOAT4 inValue0=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c0));\n"
" COMPUTE_FLOAT4 inValue1=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c0));\n"
" COMPUTE_FLOAT4 inValue4=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c1));\n"
" COMPUTE_FLOAT4 inValue5=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c1));\n"
" //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n"
" //index: [0,filterIdx,0,inChannelBlockIdx]\n"
" COMPUTE_FLOAT4 weights_0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+0)*4));\n"
" COMPUTE_FLOAT4 weights_1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+1)*4));\n"
" outValue0=mad(inValue0,weights_0,outValue0);\n"
" outValue1=mad(inValue1,weights_0,outValue1);\n"
" \n"
" outValue4=mad(inValue4,weights_1,outValue4);\n"
" outValue5=mad(inValue5,weights_1,outValue5);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
" \n"
" outValue4=fmax(outValue4,(COMPUTE_FLOAT4)0);\n"
" outValue5=fmax(outValue5,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" \n"
" outValue4=clamp(outValue4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue5=clamp(outValue5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n"
" const int remain=out_hw.y-out_w2_idx;\n"
" if (remain >= 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
" \n"
" if(c_idx+1 >= c_blocks) return;\n"
" \n"
" out_offset += batch*out_hw.x*out_hw.y*4;\n"
" if (remain >= 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue5),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n"
" }\n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_s1_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/4\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" COMPUTE_FLOAT4 outValue2=outValue0;\n"
" COMPUTE_FLOAT4 outValue3=outValue0;\n"
" const int out_w4_idx=out_w_idx << 2;\n"
" const int in_w_start_0=out_w4_idx-pad_hw.y;\n"
" const int in_w_start_1=in_w_start_0+1;\n"
" const int in_w_start_2=in_w_start_0+2;\n"
" const int in_w_start_3=in_w_start_0+3;\n"
" const int in_h_start=out_h_idx-pad_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 inValue0,inValue1,inValue2,inValue3;\n"
" for (int kh=0; kh<filter_hw.x; kh++) {\n"
" const int in_h_cur=in_h_start+kh;\n"
" if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n"
" \n"
" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n"
" for (int kw=0; kw<filter_hw.y; kw++) {\n"
" const int filter_idx=mad24(kh,filter_hw.y,kw);\n"
" inValue0=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset));\n"
" inValue1=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset));\n"
" inValue2=(in_w_start_2+kw<0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2,input+inp_offset));\n"
" inValue3=(in_w_start_3+kw<0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3,input+inp_offset));\n"
" //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n"
" //index: [0,filterIdx,0,inChannelBlockIdx]\n"
" COMPUTE_FLOAT4 weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" outValue2=mad(inValue2,weights,outValue2);\n"
" outValue3=mad(inValue3,weights,outValue3);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
" outValue2=fmax(outValue2,(COMPUTE_FLOAT4)0);\n"
" outValue3=fmax(outValue3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue2=clamp(outValue2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue3=clamp(outValue3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n"
" const int remain=out_hw.y-out_w4_idx;\n"
" if (remain >= 4) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),2,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue3),3,output+out_offset);\n"
" } else if (remain == 3) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),2,output+out_offset);\n"
" } else if (remain == 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_k3s1p1_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/2\n"
" const int out_b_h_idx=get_global_id(1);// b*h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_hw.x;\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" const int out_w2_idx=out_w_idx << 1;\n"
" const int in_w_start_0=out_w2_idx-pad_hw.y;\n"
" const int in_h_start=out_h_idx-pad_hw.x;\n"
" COMPUTE_FLOAT4 inValue0,inValue1,inValue2,inValue3;\n"
" //first line\n"
" const int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_start)* in_hw.y+in_w_start_0)*4;\n"
" inValue0=(in_h_start<0 || in_w_start_0<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" inValue1=(in_h_start<0 || in_w_start_0+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n"
" inValue2=(in_h_start<0 || in_w_start_0+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2,input+inp_offset));\n"
" inValue3=(in_h_start<0 || in_w_start_0+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(3,input+inp_offset));\n"
" int filter_idx=mad24(0,filter_hw.y,0);\n"
" COMPUTE_FLOAT4 weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" filter_idx=mad24(0,filter_hw.y,1);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue1,weights,outValue0);\n"
" outValue1=mad(inValue2,weights,outValue1);\n"
" \n"
" filter_idx=mad24(0,filter_hw.y,2);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue2,weights,outValue0);\n"
" outValue1=mad(inValue3,weights,outValue1);\n"
" //second line\n"
" inValue0=(in_h_start+1 >= in_hw.x || in_w_start_0<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+0,input+inp_offset));\n"
" inValue1=(in_h_start+1 >= in_hw.x || in_w_start_0+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+1,input+inp_offset));\n"
" inValue2=(in_h_start+1 >= in_hw.x || in_w_start_0+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+2,input+inp_offset));\n"
" inValue3=(in_h_start+1 >= in_hw.x || in_w_start_0+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+3,input+inp_offset));\n"
" \n"
" filter_idx=mad24(1,filter_hw.y,0);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" filter_idx=mad24(1,filter_hw.y,1);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue1,weights,outValue0);\n"
" outValue1=mad(inValue2,weights,outValue1);\n"
" \n"
" filter_idx=mad24(1,filter_hw.y,2);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue2,weights,outValue0);\n"
" outValue1=mad(inValue3,weights,outValue1);\n"
" \n"
" //third line\n"
" inValue0=(in_h_start+2 >= in_hw.x || in_w_start_0<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+0,input+inp_offset));\n"
" inValue1=(in_h_start+2 >= in_hw.x || in_w_start_0+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+1,input+inp_offset));\n"
" inValue2=(in_h_start+2 >= in_hw.x || in_w_start_0+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+2,input+inp_offset));\n"
" inValue3=(in_h_start+2 >= in_hw.x || in_w_start_0+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+3,input+inp_offset));\n"
" \n"
" filter_idx=mad24(2,filter_hw.y,0);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" filter_idx=mad24(2,filter_hw.y,1);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue1,weights,outValue0);\n"
" outValue1=mad(inValue2,weights,outValue1);\n"
" \n"
" filter_idx=mad24(2,filter_hw.y,2);\n"
" weights=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue2,weights,outValue0);\n"
" outValue1=mad(inValue3,weights,outValue1);\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n"
" const int remain=out_hw.y-out_w2_idx;\n"
" if (remain >= 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
"}\n"
"__kernel\n"
"void depthwise_conv2d_k3s1p1_c4h2w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,\n"
" __global const FLOAT *filter,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int c_blocks) {\n"
" const int out_c_w_idx=get_global_id(0);// oc/4*ow/2\n"
" const int out_b_h_idx=get_global_id(1);// b*h/2\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_h_blocks=(out_hw.x+1)/2;\n"
" const int c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int b_idx=out_b_h_idx/out_h_blocks;\n"
" const int out_h_idx=out_b_h_idx % out_h_blocks;\n"
" COMPUTE_FLOAT4 outValue0=CONVERT_COMPUTE_FLOAT4(vload4(c_idx,bias));\n"
" COMPUTE_FLOAT4 outValue1=outValue0;\n"
" COMPUTE_FLOAT4 outValue2=outValue0;\n"
" COMPUTE_FLOAT4 outValue3=outValue0;\n"
" const int out_w2_idx=out_w_idx << 1;\n"
" const int in_w_start=out_w2_idx-pad_hw.y;\n"
" const int out_h2_idx=out_h_idx << 1;\n"
" const int in_h_start=out_h2_idx-pad_hw.x;\n"
" COMPUTE_FLOAT4 inValue0,inValue1,inValue2,inValue3;\n"
" //first line\n"
" const int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_start)* in_hw.y+in_w_start)*4;\n"
" inValue0=(in_h_start<0 || in_w_start<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" inValue1=(in_h_start<0 || in_w_start+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n"
" inValue2=(in_h_start<0 || in_w_start+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2,input+inp_offset));\n"
" inValue3=(in_h_start<0 || in_w_start+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(3,input+inp_offset));\n"
" int filter_idx=mad24(0,filter_hw.y,0);\n"
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights0,outValue0);\n"
" outValue1=mad(inValue1,weights0,outValue1);\n"
" filter_idx=mad24(0,filter_hw.y,1);\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue1,weights1,outValue0);\n"
" outValue1=mad(inValue2,weights1,outValue1);\n"
" \n"
" \n"
" filter_idx=mad24(0,filter_hw.y,2);\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue2,weights2,outValue0);\n"
" outValue1=mad(inValue3,weights2,outValue1);\n"
" //second line\n"
" inValue0=(in_h_start+1 >= in_hw.x || in_w_start<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+0,input+inp_offset));\n"
" inValue1=(in_h_start+1 >= in_hw.x || in_w_start+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+1,input+inp_offset));\n"
" inValue2=(in_h_start+1 >= in_hw.x || in_w_start+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+2,input+inp_offset));\n"
" inValue3=(in_h_start+1 >= in_hw.x || in_w_start+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_hw.y+3,input+inp_offset));\n"
" \n"
" DW_CONV_NEXT_LINE_CAL(outValue2,outValue3)\n"
" \n"
" filter_idx=mad24(1,filter_hw.y,0);\n"
" weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights0,outValue0);\n"
" outValue1=mad(inValue1,weights0,outValue1);\n"
" filter_idx=mad24(1,filter_hw.y,1);\n"
" weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue1,weights1,outValue0);\n"
" outValue1=mad(inValue2,weights1,outValue1);\n"
" \n"
" filter_idx=mad24(1,filter_hw.y,2);\n"
" weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue2,weights2,outValue0);\n"
" outValue1=mad(inValue3,weights2,outValue1);\n"
" \n"
" //third line\n"
" inValue0=(in_h_start+2 >= in_hw.x || in_w_start<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+0,input+inp_offset));\n"
" inValue1=(in_h_start+2 >= in_hw.x || in_w_start+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+1,input+inp_offset));\n"
" inValue2=(in_h_start+2 >= in_hw.x || in_w_start+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+2,input+inp_offset));\n"
" inValue3=(in_h_start+2 >= in_hw.x || in_w_start+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2*in_hw.y+3,input+inp_offset));\n"
" \n"
" DW_CONV_NEXT_LINE_CAL(outValue2,outValue3)\n"
" \n"
" filter_idx=mad24(2,filter_hw.y,0);\n"
" weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue0,weights0,outValue0);\n"
" outValue1=mad(inValue1,weights0,outValue1);\n"
" filter_idx=mad24(2,filter_hw.y,1);\n"
" weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue1,weights1,outValue0);\n"
" outValue1=mad(inValue2,weights1,outValue1);\n"
" \n"
" filter_idx=mad24(2,filter_hw.y,2);\n"
" weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx)*4));\n"
" outValue0=mad(inValue2,weights2,outValue0);\n"
" outValue1=mad(inValue3,weights2,outValue1);\n"
" \n"
" //forth line\n"
" inValue0=(in_h_start+3 >= in_hw.x || in_w_start<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(3*in_hw.y+0,input+inp_offset));\n"
" inValue1=(in_h_start+3 >= in_hw.x || in_w_start+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(3*in_hw.y+1,input+inp_offset));\n"
" inValue2=(in_h_start+3 >= in_hw.x || in_w_start+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(3*in_hw.y+2,input+inp_offset));\n"
" inValue3=(in_h_start+3 >= in_hw.x || in_w_start+3 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(3*in_hw.y+3,input+inp_offset));\n"
" \n"
" DW_CONV_NEXT_LINE_CAL(outValue2,outValue3)\n"
" \n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(COMPUTE_FLOAT4)0);\n"
" outValue1=fmax(outValue1,(COMPUTE_FLOAT4)0);\n"
" outValue2=fmax(outValue2,(COMPUTE_FLOAT4)0);\n"
" outValue3=fmax(outValue3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue2=clamp(outValue2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" outValue3=clamp(outValue3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h2_idx)*out_hw.y+out_w2_idx)*4;\n"
" const int remain_w=out_hw.y-out_w2_idx;\n"
" const int remain_h=out_hw.x-out_h2_idx;\n"
" if(remain_w >= 2 && remain_h >= 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),out_hw.y+0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue3),out_hw.y+1,output+out_offset);\n"
" } else if(remain_w == 1 && remain_h >= 2) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue2),out_hw.y+0,output+out_offset);\n"
" } else if(remain_w >= 2 && remain_h == 1) {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(outValue1),1,output+out_offset);\n"
" } else {\n"
" vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source (buffer variant) of the Winograd F(2x2, 3x3) transform kernels:
// winoTransWeightBuf2_3_1, winoTransSrcBuf2_3_1 and winoTransDstBuf2_3_1.
// Everything inside the quotes is kernel program text; these C++ line comments sit
// between adjacent string literals and are NOT part of the resulting string.
// FLOAT / FLOAT4 / FLOAT8 and CONVERT_FLOAT8 are not defined in this text; presumably
// they are supplied as program build options by the host code — verify.
const char* winogradTransform_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_DIM2 expands to the two leading kernel parameters carrying the logical
// 2D global size; the empty "" literals only keep the macro body on a single line.
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
// UNIFORM_BOUNDRY_CHECK returns early from work-items whose ids lie outside that
// logical size (presumably the launched NDRange may be rounded up — not shown here).
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"// [dstChannel,srcChannel,3,3] -> [4x4,srcChannelPad,dstChannelpad] (N,Kpad,Npad)\n"
"__kernel void winoTransWeightBuf2_3_1(GLOBAL_SIZE_DIM2\n"
" __global const float* input,// 0\n"
" __global FLOAT* output,\n"
" __private const int srcChannel,// 3\n"
" __private const int dstChannel,\n"
" __private const int srcChannelPad,// 6\n"
" __private const int dstChannelPad\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" \n"
" const int src_c=pos.x;\n"
" const int dst_c=pos.y;\n"
" \n"
" const int out_offset=(0*srcChannelPad+src_c)*dstChannelPad+dst_c;\n"
" const int out_offset_add=srcChannelPad*dstChannelPad;\n"
" if(src_c >= srcChannel || dst_c >= dstChannel) {\n"
" for(int i=0; i<16; i++) {\n"
" output[out_offset+i*out_offset_add]=(FLOAT)0;\n"
" }\n"
" return;\n"
" }\n"
" \n"
" const int in_offset=(dst_c*srcChannel+src_c)*9;\n"
" FLOAT8 in=CONVERT_FLOAT8(vload8(0,input+in_offset));\n"
" FLOAT in8=input[in_offset+8];\n"
" \n"
" FLOAT GB_00=in.s0;\n"
" FLOAT GB_01=in.s1;\n"
" FLOAT GB_02=in.s2;\n"
" FLOAT GB_10=in.s0+in.s3+in.s6;\n"
" FLOAT GB_11=in.s1+in.s4+in.s7;\n"
" FLOAT GB_12=in.s2+in.s5+in8;\n"
" FLOAT GB_20=in.s0-in.s3+in.s6;\n"
" FLOAT GB_21=in.s1-in.s4+in.s7;\n"
" FLOAT GB_22=in.s2-in.s5+in8;\n"
" FLOAT GB_30=in.s6;\n"
" FLOAT GB_31=in.s7;\n"
" FLOAT GB_32=in8;\n"
" \n"
" FLOAT GBGT_00=GB_00;\n"
" FLOAT GBGT_01=GB_00+GB_01+GB_02;\n"
" FLOAT GBGT_02=GB_00-GB_01+GB_02;\n"
" FLOAT GBGT_03=GB_02;\n"
" \n"
" FLOAT GBGT_10=GB_10;\n"
" FLOAT GBGT_11=GB_10+GB_11+GB_12;\n"
" FLOAT GBGT_12=GB_10-GB_11+GB_12;\n"
" FLOAT GBGT_13=GB_12;\n"
" \n"
" FLOAT GBGT_20=GB_20;\n"
" FLOAT GBGT_21=GB_20+GB_21+GB_22;\n"
" FLOAT GBGT_22=GB_20-GB_21+GB_22;\n"
" FLOAT GBGT_23=GB_22;\n"
" \n"
" FLOAT GBGT_30=GB_30;\n"
" FLOAT GBGT_31=GB_30+GB_31+GB_32;\n"
" FLOAT GBGT_32=GB_30-GB_31+GB_32;\n"
" FLOAT GBGT_33=GB_32;\n"
" output[out_offset+0*out_offset_add]=GBGT_00;\n"
" output[out_offset+1*out_offset_add]=GBGT_01;\n"
" output[out_offset+2*out_offset_add]=GBGT_02;\n"
" output[out_offset+3*out_offset_add]=GBGT_03;\n"
" output[out_offset+4*out_offset_add]=GBGT_10;\n"
" output[out_offset+5*out_offset_add]=GBGT_11;\n"
" output[out_offset+6*out_offset_add]=GBGT_12;\n"
" output[out_offset+7*out_offset_add]=GBGT_13;\n"
" output[out_offset+8*out_offset_add]=GBGT_20;\n"
" output[out_offset+9*out_offset_add]=GBGT_21;\n"
" output[out_offset+10*out_offset_add]=GBGT_22;\n"
" output[out_offset+11*out_offset_add]=GBGT_23;\n"
" output[out_offset+12*out_offset_add]=GBGT_30;\n"
" output[out_offset+13*out_offset_add]=GBGT_31;\n"
" output[out_offset+14*out_offset_add]=GBGT_32;\n"
" output[out_offset+15*out_offset_add]=GBGT_33;\n"
"}\n"
// winoTransSrcBuf2_3_1: input-tile transform.
// - Global id 0 (pos.x) selects a tile: unitWidth tiles per row, unitHeight rows.
// - Global id 1 (pos.y) selects a 4-channel block srcZ.
// Each work-item:
// - reads a 4x4 window of FLOAT4 values with top-left corner
//   (2*tile_x - padX, 2*tile_y - padY), zero-filling positions outside
//   [0,srcWidth) x [0,srcHeight);
// - applies the Winograd input transform;
// - writes 16 transformed values x 4 channels.
"__kernel void winoTransSrcBuf2_3_1(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,// 0\n"
" __global FLOAT* uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,\n"
" __private const int dstHeightPad,__private const int srcChannelPad,\n"
" __private const int batch,\n"
" __private const int batchOffset) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" \n"
// Extra guard beyond UNIFORM_BOUNDRY_CHECK: skip ids past the tile count or the
// channel-block count.
" if(pos.x >= unitWidth*unitHeight || pos.y >= srcChannelC4) {\n"
" return;\n"
" }\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" int dstXOrigin=pos.y;\n"
// NOTE(review): the pos.y/srcChannelC4 value is unused. batchIndex is overwritten
// with batchOffset two lines below, so each launch processes only the batch chosen
// by the host. srcZ equals pos.y here because of the guard above.
" int batchIndex=pos.y/srcChannelC4;\n"
" int srcZ=pos.y % srcChannelC4;\n"
" int dstYOrigin=unitWidth*unitHeight_idx+unitWidth_idx;\n"
" batchIndex=batchOffset;\n"
" {\n"
// Adjacent tiles start 2 pixels apart, so consecutive 4x4 windows overlap by 2.
" int sxStart=(realPos.x)*2-padX;\n"
" int syStart=(realPos.y)*2-padY;\n"
" FLOAT4 S00;\n"
" FLOAT4 S10;\n"
" FLOAT4 S20;\n"
" FLOAT4 S30;\n"
" FLOAT4 S01;\n"
" FLOAT4 S11;\n"
" FLOAT4 S21;\n"
" FLOAT4 S31;\n"
" FLOAT4 S02;\n"
" FLOAT4 S12;\n"
" FLOAT4 S22;\n"
" FLOAT4 S32;\n"
" FLOAT4 S03;\n"
" FLOAT4 S13;\n"
" FLOAT4 S23;\n"
" FLOAT4 S33;\n"
" \n"
// Input element (batch b, channel block z, y, x) starts at
// (((b + z*batch)*srcHeight + y)*srcWidth + x)*4. inp_offset may be built from
// negative coordinates, but uInput+offset is only evaluated in the in-bounds arm of
// each ternary below. S_xy denotes the window element at column x, row y.
" int inp_offset=(((batchIndex+srcZ*batch)*srcHeight+syStart)*srcWidth+sxStart)*4;\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S00=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S10=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S20=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=0+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S30=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S01=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S11=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S21=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S31=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S02=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S12=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S22=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S32=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S03=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S13=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S23=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S33=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+12);\n"
" }\n"
// First pass, along y for each column x. mXk is row k applied to column X:
// row 0 = S0 - S2, row 1 = 0.5*(S1 + S2), row 2 = 0.5*(S2 - S1), row 3 = S3 - S1.
" FLOAT4 m00=+S00-S02;\n"
" FLOAT4 m10=+S10-S12;\n"
" FLOAT4 m20=+S20-S22;\n"
" FLOAT4 m30=+S30-S32;\n"
" FLOAT4 m01=+(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m11=+(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m21=+(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m31=+(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m02=-(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m12=-(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m22=-(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m32=-(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m03=-S01+S03;\n"
" FLOAT4 m13=-S11+S13;\n"
" FLOAT4 m23=-S21+S23;\n"
" FLOAT4 m33=-S31+S33;\n"
" \n"
// Output layout [16, srcChannelPad, dstHeightPad]. Component c of transformed value k
// is stored at (k*srcChannelPad + 4*dstXOrigin + c)*dstHeightPad + dstYOrigin, where
// dstYOrigin is the tile index. The local named batch_offset is the stride between
// the 16 planes, not a batch stride.
" //NC4HW4 [alpha*alpha,srcChannelPad,dstHeightPad]\n"
" //index: [0,dstXOrigin,dstY,dstYOrigin % 4]\n"
" int out_offset=(0*srcChannelPad+4*dstXOrigin)*dstHeightPad+dstYOrigin;\n"
" int batch_offset=srcChannelPad*dstHeightPad;\n"
" \n"
// Second pass, along x, with the same four row combinations. Plane
// k = 4*(y-row) + (x-row), matching the 4*i + j plane order written by
// winoTransWeightBuf2_3_1.
" FLOAT4 res=(+m00-m20);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" out_offset += batch_offset;\n"
" res=(+(FLOAT)0.5f*m10+(FLOAT)0.5f*m20);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-(FLOAT)0.5f*m10+(FLOAT)0.5f*m20);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-m10+m30);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" \n"
" out_offset += batch_offset;\n"
" res=(+m01-m21);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(+(FLOAT)0.5f*m11+(FLOAT)0.5f*m21);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-(FLOAT)0.5f*m11+(FLOAT)0.5f*m21);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-m11+m31);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(+m02-m22);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(+(FLOAT)0.5f*m12+(FLOAT)0.5f*m22);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-(FLOAT)0.5f*m12+(FLOAT)0.5f*m22);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-m12+m32);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(+m03-m23);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(+(FLOAT)0.5f*m13+(FLOAT)0.5f*m23);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-(FLOAT)0.5f*m13+(FLOAT)0.5f*m23);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" \n"
" out_offset += batch_offset;\n"
" res=(-m13+m33);\n"
" uOutput[out_offset]=res.x;\n"
" uOutput[out_offset+dstHeightPad]=res.y;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad]=res.z;\n"
" uOutput[out_offset+dstHeightPad+dstHeightPad+dstHeightPad]=res.w;\n"
" }\n"
"}\n"
// winoTransDstBuf2_3_1: output-tile (inverse) transform.
// - Global id 0 (pos.x) selects a tile.
// - Global id 1 gives the output channel block oz = pos.y % dstChannelC4.
// Each work-item:
// - reads 16 FLOAT4 planes from `uInput`, laid out as [16, srcWidthPad, dstChannelPad];
// - applies the inverse transform with rows (1,1,1,0) and (0,1,-1,1) in both directions;
// - adds the 4-channel bias for oz and optionally applies RELU / RELU6;
// - writes up to 2x2 outputs, clipped at the right and bottom image edges.
// NOTE(review): unlike winoTransSrcBuf2_3_1, only UNIFORM_BOUNDRY_CHECK guards pos.x.
// The host presumably sets global_size_dim0 <= unitWidth*unitHeight — verify.
"__kernel void winoTransDstBuf2_3_1(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,\n"
" __global const FLOAT* uBias,\n"
" __global FLOAT* uOutput,\n"
" __private const int unitWidth,//wUnit\n"
" __private const int unitHeight,//hUnit\n"
" __private const int dstWidth,\n"
" __private const int dstHeight,\n"
" __private const int dstChannelC4,\n"
" __private const int srcWidthPad,\n"
" __private const int dstChannelPad,\n"
" __private const int batch,\n"
" __private const int batchOffset) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" int dstXOrigin=unitWidth*unitHeight_idx+unitWidth_idx;\n"
" int oz=pos.y % dstChannelC4;\n"
" \n"
" FLOAT4 bias=vload4(0,uBias+oz*4);\n"
// NOTE(review): the pos.y/dstChannelC4 value is unused. batchIndex is overwritten
// with batchOffset on the next line, so each launch writes only the batch chosen by
// the host.
" int batchIndex=pos.y/dstChannelC4;\n"
" batchIndex=batchOffset;\n"
" {\n"
" int oyStart=realPos.y*2;\n"
" int oxStart=realPos.x*2;\n"
" \n"
// Plane k of this tile starts at k*dstChannelPad*srcWidthPad + dstXOrigin*dstChannelPad
// + 4*oz. S_xy is read from plane x + 4*y.
" // [alpha2,srcWidthPad,dstChannelPad]\n"
" //index: [0,dstXOrigin,4*oz]\n"
" const int inp_offset=(0*srcWidthPad+dstXOrigin)*dstChannelPad+4*oz;\n"
" const int b_offset=dstChannelPad*srcWidthPad;\n"
" FLOAT4 S00=vload4(0,uInput+inp_offset+b_offset*0);\n"
" FLOAT4 S10=vload4(0,uInput+inp_offset+b_offset*1);\n"
" FLOAT4 S20=vload4(0,uInput+inp_offset+b_offset*2);\n"
" FLOAT4 S30=vload4(0,uInput+inp_offset+b_offset*3);\n"
" FLOAT4 S01=vload4(0,uInput+inp_offset+b_offset*4);\n"
" FLOAT4 S11=vload4(0,uInput+inp_offset+b_offset*5);\n"
" FLOAT4 S21=vload4(0,uInput+inp_offset+b_offset*6);\n"
" FLOAT4 S31=vload4(0,uInput+inp_offset+b_offset*7);\n"
" FLOAT4 S02=vload4(0,uInput+inp_offset+b_offset*8);\n"
" FLOAT4 S12=vload4(0,uInput+inp_offset+b_offset*9);\n"
" FLOAT4 S22=vload4(0,uInput+inp_offset+b_offset*10);\n"
" FLOAT4 S32=vload4(0,uInput+inp_offset+b_offset*11);\n"
" FLOAT4 S03=vload4(0,uInput+inp_offset+b_offset*12);\n"
" FLOAT4 S13=vload4(0,uInput+inp_offset+b_offset*13);\n"
" FLOAT4 S23=vload4(0,uInput+inp_offset+b_offset*14);\n"
" FLOAT4 S33=vload4(0,uInput+inp_offset+b_offset*15);\n"
// First pass, along y for each x: row 0 = S0 + S1 + S2, row 1 = S1 - S2 + S3.
" FLOAT4 m00=+S00+S01+S02;\n"
" FLOAT4 m10=+S10+S11+S12;\n"
" FLOAT4 m20=+S20+S21+S22;\n"
" FLOAT4 m30=+S30+S31+S32;\n"
" FLOAT4 m01=+S01-S02+S03;\n"
" FLOAT4 m11=+S11-S12+S13;\n"
" FLOAT4 m21=+S21-S22+S23;\n"
" FLOAT4 m31=+S31-S32+S33;\n"
" \n"
// Output element (batch b, channel block oz, oy, ox) starts at
// (((b + oz*batch)*dstHeight + oy)*dstWidth + ox)*4. Each of the four outputs below:
// - applies the second pass along x (same two rows);
// - adds the bias, then the optional activation;
// - is skipped if it falls outside the image.
" //NC4HW4 [batch,dstChannelC4,dstHeight,dstWidth]\n"
" //index: [batchIndex,oz,oyStart,oxStart]\n"
" int out_offset=(((batchIndex+oz*batch)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m00+m10+m20;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m10-m20+m30;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m01+m11+m21;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4*dstWidth);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m11-m21+m31;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4*dstWidth+4);\n"
" }\n"
" }\n"
" }\n"
"}\n"
// End of the winogradTransform_buf program text.
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
const char* winogradTransform_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
" \n"
"#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void winoTransSrcBuf2_3_1_c16_c16(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,// 0\n"
" __global FLOAT* uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,__private const int srcChannelC16,__private const int dstHeight,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" const int unitWidth_idx=pos.x % unitWidth;\n"
" const int unitHeight_idx=pos.x/unitWidth;\n"
" const int sglid=get_sub_group_local_id();\n"
" const int pos_y=get_group_id(1);\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" int src_pitch=srcWidth+input_pad_left+input_pad_right;\n"
" \n"
" {\n"
" int sxStart=(realPos.x)*2-padX;\n"
" int syStart=(realPos.y)*2-padY;\n"
" FLOAT4 S[4];\n"
" \n"
" int inp_offset=(((batchOffset*srcChannelC16+pos_y)*srcHeight+syStart)*src_pitch+sxStart+input_pad_left)*16;\n"
" for(int i=0; i<4; ++i){\n"
" int sy=i+syStart;\n"
" if(sy<0 || sy >= srcHeight){\n"
" S[i]=(FLOAT4)0;\n"
" }else{\n"
"#ifdef MNN_SUPPORT_FP16\n"
" S[i]=as_half4(intel_sub_group_block_read_us4((__global ushort*)(uInput+inp_offset)));\n"
"#else\n"
" S[i]=as_float4(intel_sub_group_block_read4((__global uint*)(uInput+inp_offset)));\n"
"#endif\n"
" }\n"
" inp_offset += 16*src_pitch;\n"
" }\n"
" FLOAT m00=+S[0].s0-S[2].s0;\n"
" FLOAT m10=+S[0].s1-S[2].s1;\n"
" FLOAT m20=+S[0].s2-S[2].s2;\n"
" FLOAT m30=+S[0].s3-S[2].s3;\n"
" FLOAT m01=+(FLOAT)0.5f*S[1].s0+(FLOAT)0.5f*S[2].s0;\n"
" FLOAT m11=+(FLOAT)0.5f*S[1].s1+(FLOAT)0.5f*S[2].s1;\n"
" FLOAT m21=+(FLOAT)0.5f*S[1].s2+(FLOAT)0.5f*S[2].s2;\n"
" FLOAT m31=+(FLOAT)0.5f*S[1].s3+(FLOAT)0.5f*S[2].s3;\n"
" FLOAT m02=-(FLOAT)0.5f*S[1].s0+(FLOAT)0.5f*S[2].s0;\n"
" FLOAT m12=-(FLOAT)0.5f*S[1].s1+(FLOAT)0.5f*S[2].s1;\n"
" FLOAT m22=-(FLOAT)0.5f*S[1].s2+(FLOAT)0.5f*S[2].s2;\n"
" FLOAT m32=-(FLOAT)0.5f*S[1].s3+(FLOAT)0.5f*S[2].s3;\n"
" FLOAT m03=-S[1].s0+S[3].s0;\n"
" FLOAT m13=-S[1].s1+S[3].s1;\n"
" FLOAT m23=-S[1].s2+S[3].s2;\n"
" FLOAT m33=-S[1].s3+S[3].s3;\n"
" \n"
" //NC4HW4 [alpha*alpha,srcChannelC16,dstHeight,16]\n"
" //index: [0,pos.y/16,pos.x,0]\n"
" int out_offset=(pos_y*dstHeight+pos.x)*16+sglid;\n"
" int batch_offset=srcChannelC16*dstHeight*16;\n"
" uOutput[out_offset+0*batch_offset]=+m00-m20;\n"
" uOutput[out_offset+1*batch_offset]=+(FLOAT)0.5f*m10+(FLOAT)0.5f*m20;\n"
" uOutput[out_offset+2*batch_offset]=-(FLOAT)0.5f*m10+(FLOAT)0.5f*m20;\n"
" uOutput[out_offset+3*batch_offset]=-m10+m30;\n"
" uOutput[out_offset+4*batch_offset]=+m01-m21;\n"
" uOutput[out_offset+5*batch_offset]=+(FLOAT)0.5f*m11+(FLOAT)0.5f*m21;\n"
" uOutput[out_offset+6*batch_offset]=-(FLOAT)0.5f*m11+(FLOAT)0.5f*m21;\n"
" uOutput[out_offset+7*batch_offset]=-m11+m31;\n"
" uOutput[out_offset+8*batch_offset]=+m02-m22;\n"
" uOutput[out_offset+9*batch_offset]=+(FLOAT)0.5f*m12+(FLOAT)0.5f*m22;\n"
" uOutput[out_offset+10*batch_offset]=-(FLOAT)0.5f*m12+(FLOAT)0.5f*m22;\n"
" uOutput[out_offset+11*batch_offset]=-m12+m32;\n"
" uOutput[out_offset+12*batch_offset]=+m03-m23;\n"
" uOutput[out_offset+13*batch_offset]=+(FLOAT)0.5f*m13+(FLOAT)0.5f*m23;\n"
" uOutput[out_offset+14*batch_offset]=-(FLOAT)0.5f*m13+(FLOAT)0.5f*m23;\n"
" uOutput[out_offset+15*batch_offset]=-m13+m33;\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void winoTransDstBuf2_3_1_c16_c16(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,\n"
" __global const FLOAT* uBias,\n"
" __global FLOAT* uOutput,\n"
" __private const int unitWidth,//wUnit\n"
" __private const int unitHeight,//hUnit\n"
" __private const int dstWidth,\n"
" __private const int dstHeight,\n"
" __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" const int unitWidth_idx=pos.x % unitWidth;\n"
" const int unitHeight_idx=pos.x/unitWidth; \n"
" const int sglid=get_sub_group_local_id();\n"
" const int pos_y=get_group_id(1);\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" \n"
" FLOAT bias=uBias[pos.y];\n"
" {\n"
" int oyStart=realPos.y*2;\n"
" int oxStart=realPos.x*2;\n"
" \n"
" //NC4HW4 [alpha2,dstChannelC16,wUnit*hUnit,16]\n"
" //index: [0,pos.y/4,pos.x,pos.y%4]\n"
" const int inp_offset=(pos_y*srcWidth+pos.x)*16+sglid;\n"
" const int ic_offset=16*srcWidth*dstChannelC16;\n"
" FLOAT S00=uInput[inp_offset+ic_offset*0];\n"
" FLOAT S10=uInput[inp_offset+ic_offset*1];\n"
" FLOAT S20=uInput[inp_offset+ic_offset*2];\n"
" FLOAT S30=uInput[inp_offset+ic_offset*3];\n"
" FLOAT S01=uInput[inp_offset+ic_offset*4];\n"
" FLOAT S11=uInput[inp_offset+ic_offset*5];\n"
" FLOAT S21=uInput[inp_offset+ic_offset*6];\n"
" FLOAT S31=uInput[inp_offset+ic_offset*7];\n"
" FLOAT S02=uInput[inp_offset+ic_offset*8];\n"
" FLOAT S12=uInput[inp_offset+ic_offset*9];\n"
" FLOAT S22=uInput[inp_offset+ic_offset*10];\n"
" FLOAT S32=uInput[inp_offset+ic_offset*11];\n"
" FLOAT S03=uInput[inp_offset+ic_offset*12];\n"
" FLOAT S13=uInput[inp_offset+ic_offset*13];\n"
" FLOAT S23=uInput[inp_offset+ic_offset*14];\n"
" FLOAT S33=uInput[inp_offset+ic_offset*15];\n"
" FLOAT m00=+S00+S01+S02;\n"
" FLOAT m10=+S10+S11+S12;\n"
" FLOAT m20=+S20+S21+S22;\n"
" FLOAT m30=+S30+S31+S32;\n"
" FLOAT m01=+S01-S02+S03;\n"
" FLOAT m11=+S11-S12+S13;\n"
" FLOAT m21=+S21-S22+S23;\n"
" FLOAT m31=+S31-S32+S33;\n"
" \n"
" //NC4HW4 [batch,dstChannelC4,dstHeight,dstWidth]\n"
" //index: [batchOffset,pos.y,oyStart,oxStart]\n"
" int dst_pitch=dstWidth+output_pad_left+output_pad_right;\n"
" int out_offset=(((batchOffset*dstChannelC16+ pos_y)*dstHeight+oyStart)*dst_pitch+oxStart+output_pad_left)*16+sglid;\n"
" {\n"
" FLOAT2 res=(FLOAT2)(bias+m00+m10+m20,bias+m10-m20+m30);\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT2)0,(FLOAT2)6);\n"
"#endif\n"
"#if OUTPUT_LEFTOVERS\n"
" uOutput[out_offset]=res.x;\n"
" if(oxStart+1< dstWidth){\n"
" uOutput[out_offset+16]=res.y;\n"
" }\n"
"#else\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us2((__global ushort*)(uOutput+out_offset),as_ushort2(res));\n"
"#else\n"
" intel_sub_group_block_write2((__global uint*)(uOutput+out_offset),as_uint2(res));\n"
"#endif\n"
"#endif //OUTPUT_LEFTOVERS\n"
" }\n"
" {\n"
" int oy=oyStart+1;\n"
" if (oy<dstHeight) {\n"
" FLOAT2 res=(FLOAT2)(bias+m01+m11+m21,bias+m11-m21+m31);\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT2)0,(FLOAT2)6);\n"
"#endif\n"
"#if OUTPUT_LEFTOVERS\n"
" uOutput[out_offset+16*dst_pitch]=res.x;\n"
" if(oxStart+1< dstWidth){\n"
" uOutput[out_offset+16+16*dst_pitch]=res.y;\n"
" }\n"
"#else\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us2((__global ushort*)(uOutput+out_offset+16*dst_pitch),as_ushort2(res));\n"
"#else\n"
" intel_sub_group_block_write2((__global uint*)(uOutput+out_offset+16*dst_pitch),as_uint2(res));\n"
"#endif\n"
"#endif //OUTPUT_LEFTOVERS\n"
" }\n"
" }\n"
" if(unitWidth_idx == 0){\n"
" int pad_offset=(((batchOffset*dstChannelC16+ pos_y)*dstHeight+oyStart)*dst_pitch)*16+sglid;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" uOutput[pad_offset+i*16]=0;\n"
" uOutput[pad_offset+(i+dst_pitch)*16]=0;\n"
" }\n"
" }\n"
" if(unitWidth_idx == unitWidth-1){\n"
" int pad_offset=(((batchOffset*dstChannelC16+ pos_y)*dstHeight+oyStart)*dst_pitch+output_pad_left+dstWidth)*16+sglid;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" uOutput[pad_offset+i*16]=0;\n"
" uOutput[pad_offset+(i+dst_pitch)*16]=0;\n"
" }\n"
" }\n"
" }\n"
"}\n"
"__kernel void winoTransSrcBuf2_3_1_c4_c16(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,// 0\n"
" __global FLOAT* uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,__private const int srcChannelC16,__private const int dstHeight,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" \n"
" {\n"
" int sxStart=(realPos.x)*2-padX;\n"
" int syStart=(realPos.y)*2-padY;\n"
" FLOAT4 S00;\n"
" FLOAT4 S10;\n"
" FLOAT4 S20;\n"
" FLOAT4 S30;\n"
" FLOAT4 S01;\n"
" FLOAT4 S11;\n"
" FLOAT4 S21;\n"
" FLOAT4 S31;\n"
" FLOAT4 S02;\n"
" FLOAT4 S12;\n"
" FLOAT4 S22;\n"
" FLOAT4 S32;\n"
" FLOAT4 S03;\n"
" FLOAT4 S13;\n"
" FLOAT4 S23;\n"
" FLOAT4 S33;\n"
" \n"
" int inp_offset=(((batchOffset+pos.y*batch)*srcHeight+syStart)*srcWidth+sxStart)*4;\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S00=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S10=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S20=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=0+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S30=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S01=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S11=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S21=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S31=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S02=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S12=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S22=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S32=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S03=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S13=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S23=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S33=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+12);\n"
" }\n"
" FLOAT4 m00=+S00-S02;\n"
" FLOAT4 m10=+S10-S12;\n"
" FLOAT4 m20=+S20-S22;\n"
" FLOAT4 m30=+S30-S32;\n"
" FLOAT4 m01=+(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m11=+(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m21=+(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m31=+(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m02=-(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m12=-(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m22=-(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m32=-(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m03=-S01+S03;\n"
" FLOAT4 m13=-S11+S13;\n"
" FLOAT4 m23=-S21+S23;\n"
" FLOAT4 m33=-S31+S33;\n"
" \n"
" //NC4HW4 [alpha*alpha,srcChannelC16,dstHeight,16]\n"
" //index: [0,pos.y/4,pos.x,pos.y % 4]\n"
" int out_offset=((pos.y/4)*dstHeight+pos.x)*16+(pos.y % 4)*4;\n"
" int batch_offset=srcChannelC16*dstHeight*16;\n"
" vstore4(+m00-m20,0,uOutput+out_offset+0*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m10+(FLOAT)0.5f*m20,0,uOutput+out_offset+1*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m10+(FLOAT)0.5f*m20,0,uOutput+out_offset+2*batch_offset);\n"
" vstore4(-m10+m30,0,uOutput+out_offset+3*batch_offset);\n"
" vstore4(+m01-m21,0,uOutput+out_offset+4*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m11+(FLOAT)0.5f*m21,0,uOutput+out_offset+5*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m11+(FLOAT)0.5f*m21,0,uOutput+out_offset+6*batch_offset);\n"
" vstore4(-m11+m31,0,uOutput+out_offset+7*batch_offset);\n"
" vstore4(+m02-m22,0,uOutput+out_offset+8*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m12+(FLOAT)0.5f*m22,0,uOutput+out_offset+9*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m12+(FLOAT)0.5f*m22,0,uOutput+out_offset+10*batch_offset);\n"
" vstore4(-m12+m32,0,uOutput+out_offset+11*batch_offset);\n"
" vstore4(+m03-m23,0,uOutput+out_offset+12*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m13+(FLOAT)0.5f*m23,0,uOutput+out_offset+13*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m13+(FLOAT)0.5f*m23,0,uOutput+out_offset+14*batch_offset);\n"
" vstore4(-m13+m33,0,uOutput+out_offset+15*batch_offset);\n"
" }\n"
"}\n"
"__kernel void winoTransDstBuf2_3_1_c16_c4(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,\n"
" __global const FLOAT* uBias,\n"
" __global FLOAT* uOutput,\n"
" __private const int unitWidth,//wUnit\n"
" __private const int unitHeight,//hUnit\n"
" __private const int dstWidth,\n"
" __private const int dstHeight,\n"
" __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" \n"
" FLOAT4 bias=vload4(0,uBias+pos.y*4);\n"
" {\n"
" int oyStart=realPos.y*2;\n"
" int oxStart=realPos.x*2;\n"
" \n"
" //NC4HW4 [alpha2,dstChannelC16,wUnit*hUnit,16]\n"
" //index: [0,pos.y/4,pos.x,pos.y%4]\n"
" const int inp_offset=((pos.y/4)*srcWidth+pos.x)*16+(pos.y % 4)*4;\n"
" const int ic_offset=16*srcWidth*dstChannelC16;\n"
" FLOAT4 S00=vload4(0,uInput+inp_offset+ic_offset*0);\n"
" FLOAT4 S10=vload4(0,uInput+inp_offset+ic_offset*1);\n"
" FLOAT4 S20=vload4(0,uInput+inp_offset+ic_offset*2);\n"
" FLOAT4 S30=vload4(0,uInput+inp_offset+ic_offset*3);\n"
" FLOAT4 S01=vload4(0,uInput+inp_offset+ic_offset*4);\n"
" FLOAT4 S11=vload4(0,uInput+inp_offset+ic_offset*5);\n"
" FLOAT4 S21=vload4(0,uInput+inp_offset+ic_offset*6);\n"
" FLOAT4 S31=vload4(0,uInput+inp_offset+ic_offset*7);\n"
" FLOAT4 S02=vload4(0,uInput+inp_offset+ic_offset*8);\n"
" FLOAT4 S12=vload4(0,uInput+inp_offset+ic_offset*9);\n"
" FLOAT4 S22=vload4(0,uInput+inp_offset+ic_offset*10);\n"
" FLOAT4 S32=vload4(0,uInput+inp_offset+ic_offset*11);\n"
" FLOAT4 S03=vload4(0,uInput+inp_offset+ic_offset*12);\n"
" FLOAT4 S13=vload4(0,uInput+inp_offset+ic_offset*13);\n"
" FLOAT4 S23=vload4(0,uInput+inp_offset+ic_offset*14);\n"
" FLOAT4 S33=vload4(0,uInput+inp_offset+ic_offset*15);\n"
" FLOAT4 m00=+S00+S01+S02;\n"
" FLOAT4 m10=+S10+S11+S12;\n"
" FLOAT4 m20=+S20+S21+S22;\n"
" FLOAT4 m30=+S30+S31+S32;\n"
" FLOAT4 m01=+S01-S02+S03;\n"
" FLOAT4 m11=+S11-S12+S13;\n"
" FLOAT4 m21=+S21-S22+S23;\n"
" FLOAT4 m31=+S31-S32+S33;\n"
" \n"
" //NC4HW4 [batch,dstChannelC4,dstHeight,dstWidth]\n"
" //index: [batchOffset,pos.y,oyStart,oxStart]\n"
" int out_offset=(((batchOffset+ pos.y*batch)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m00+m10+m20;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m10-m20+m30;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m01+m11+m21;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4*dstWidth);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m11-m21+m31;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4*dstWidth+4);\n"
" }\n"
" }\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void gemm_buf_intel(__global const FLOAT* input0,\n"
" __global const FLOAT* input1,\n"
" __global FLOAT* output,\n"
" __private const int width,//ROUND_UP(wUnit*hUnit,8)\n"
" __private const int height,//dstChannelC16\n"
" __private const int srcChannelC16,\n"
" __private const int alpha2) {\n"
" int3 pos=(int3)(get_global_id(0),get_group_id(1),get_global_id(2));\n"
" const int sglid=get_sub_group_local_id();\n"
" const int pos_x=pos.x << 3;\n"
" const int pos_y=pos.y;\n"
" FLOAT8 o=(FLOAT8)(0);\n"
" const int kernel_base=mul24(mul24(mad24(pos.z,height,pos_y),srcChannelC16),256);\n"
" const int inp_base=mul24(mad24(mul24(pos.z,srcChannelC16),width,pos_x),16);\n"
" \n"
" for(int k=0; k<srcChannelC16; ++k){\n"
" \n"
"#ifdef MNN_SUPPORT_FP16\n"
" FLOAT8 wei0=as_half8(intel_sub_group_block_read_us8((__global ushort*)(input1+kernel_base+k*256)));\n"
" FLOAT8 wei1=as_half8(intel_sub_group_block_read_us8((__global ushort*)(input1+kernel_base+k*256+8*16)));\n"
" FLOAT8 s=as_half8(intel_sub_group_block_read_us8((__global ushort*)(input0+inp_base+k*width*16)));\n"
" o=mad(wei0.s0,as_half8(intel_sub_group_shuffle(as_ushort8(s),0)),o);\n"
" o=mad(wei0.s1,as_half8(intel_sub_group_shuffle(as_ushort8(s),1)),o);\n"
" o=mad(wei0.s2,as_half8(intel_sub_group_shuffle(as_ushort8(s),2)),o);\n"
" o=mad(wei0.s3,as_half8(intel_sub_group_shuffle(as_ushort8(s),3)),o);\n"
" o=mad(wei0.s4,as_half8(intel_sub_group_shuffle(as_ushort8(s),4)),o);\n"
" o=mad(wei0.s5,as_half8(intel_sub_group_shuffle(as_ushort8(s),5)),o);\n"
" o=mad(wei0.s6,as_half8(intel_sub_group_shuffle(as_ushort8(s),6)),o);\n"
" o=mad(wei0.s7,as_half8(intel_sub_group_shuffle(as_ushort8(s),7)),o);\n"
" o=mad(wei1.s0,as_half8(intel_sub_group_shuffle(as_ushort8(s),8)),o);\n"
" o=mad(wei1.s1,as_half8(intel_sub_group_shuffle(as_ushort8(s),9)),o);\n"
" o=mad(wei1.s2,as_half8(intel_sub_group_shuffle(as_ushort8(s),10)),o);\n"
" o=mad(wei1.s3,as_half8(intel_sub_group_shuffle(as_ushort8(s),11)),o);\n"
" o=mad(wei1.s4,as_half8(intel_sub_group_shuffle(as_ushort8(s),12)),o);\n"
" o=mad(wei1.s5,as_half8(intel_sub_group_shuffle(as_ushort8(s),13)),o);\n"
" o=mad(wei1.s6,as_half8(intel_sub_group_shuffle(as_ushort8(s),14)),o);\n"
" o=mad(wei1.s7,as_half8(intel_sub_group_shuffle(as_ushort8(s),15)),o);\n"
"#else \n"
" FLOAT8 wei0=as_float8(intel_sub_group_block_read8((__global uint*)(input1+kernel_base+k*256)));\n"
" FLOAT8 wei1=as_float8(intel_sub_group_block_read8((__global uint*)(input1+kernel_base+k*256+8*16)));\n"
" FLOAT8 s=as_float8(intel_sub_group_block_read8((__global uint*)(input0+inp_base+k*width*16)));\n"
" o=mad(wei0.s0,intel_sub_group_shuffle(s,0),o);\n"
" o=mad(wei0.s1,intel_sub_group_shuffle(s,1),o);\n"
" o=mad(wei0.s2,intel_sub_group_shuffle(s,2),o);\n"
" o=mad(wei0.s3,intel_sub_group_shuffle(s,3),o);\n"
" o=mad(wei0.s4,intel_sub_group_shuffle(s,4),o);\n"
" o=mad(wei0.s5,intel_sub_group_shuffle(s,5),o);\n"
" o=mad(wei0.s6,intel_sub_group_shuffle(s,6),o);\n"
" o=mad(wei0.s7,intel_sub_group_shuffle(s,7),o);\n"
" o=mad(wei1.s0,intel_sub_group_shuffle(s,8),o);\n"
" o=mad(wei1.s1,intel_sub_group_shuffle(s,9),o);\n"
" o=mad(wei1.s2,intel_sub_group_shuffle(s,10),o);\n"
" o=mad(wei1.s3,intel_sub_group_shuffle(s,11),o);\n"
" o=mad(wei1.s4,intel_sub_group_shuffle(s,12),o);\n"
" o=mad(wei1.s5,intel_sub_group_shuffle(s,13),o);\n"
" o=mad(wei1.s6,intel_sub_group_shuffle(s,14),o);\n"
" o=mad(wei1.s7,intel_sub_group_shuffle(s,15),o);\n"
"#endif \n"
" }\n"
" int out_offset=mul24(mad24(mad24(pos.z,height,pos_y),width,pos_x),16);\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us8((__global ushort*)(output+out_offset),as_ushort8(o));\n"
"#else\n"
" intel_sub_group_block_write8((__global uint*)(output+out_offset),as_uint8(o));\n"
"#endif\n"
"}\n"
;
#endif
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for the `splitgelu_buf` kernel (buffer backend): gated GELU.
// Each input row `bc` (global dim 1) holds 2*shape.z values: a left half L and a
// right half R (R starts at offset shape.z within the row). For element h of a
// half (global dim 0 indexes h, or blocks of 16 / 4 elements under WH_16 / WH_4):
//   output[bc*shape.z + h] = L[h] * 0.5 * R[h] * (1 + erf(R[h] * 0.70710679))
// where 0.70710679 ~= 1/sqrt(2).
// With DOUBLE_INPUTS, input1[h] is added to L and input1[shape.z + h] to R
// first. input1 is not indexed by `bc`, i.e. it is broadcast across rows.
// Math runs in float and is converted to FLOAT on store. The scalar path uses
// f-suffixed literals so it stays in float like the float16/float4 paths.
// Unsuffixed literals would promote erf() to double, which needs cl_khr_fp64.
const char* splitgelu_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void splitgelu_buf(__private int global_dim0,__private int global_dim1,\n"
" __global const FLOAT*input,\n"
" #ifdef DOUBLE_INPUTS\n"
" __global const FLOAT*input1,\n"
" #endif\n"
" __global FLOAT*output,\n"
" __private const int4 shape\n"
"){\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int h=pos.x;\n"
" const int bc=pos.y;\n"
// WH_16: each work-item handles 16 consecutive elements of the half.
"// The product of W and H is a multiple of 16\n"
"#ifdef WH_16\n"
" const int in_offset=bc*shape.z*2+h*16;\n"
" const int out_offset=bc*shape.z+h*16;\n"
" float16 valueL=convert_float16(vload16(0,input+in_offset));\n"
" float16 valueR=convert_float16(vload16(0,input+in_offset+shape.z));\n"
" #ifdef DOUBLE_INPUTS\n"
" float16 valueConstL=convert_float16(vload16(h,input1));\n"
" float16 valueConstR=convert_float16(vload16(h,input1+shape.z));\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float16 out=(erf(valueR*(float16)0.7071067932881648)+(float16)1.0)*valueR*(float16)0.5;\n"
" out *= valueL;\n"
" vstore16(CONVERT_FLOAT16(out),0,output+out_offset);\n"
// WH_4: each work-item handles 4 consecutive elements of the half.
"// The product of W and H is a multiple of 4\n"
"#elif defined (WH_4)\n"
" const int in_offset=bc*shape.z*2+h*4;\n"
" const int out_offset=bc*shape.z+h*4;\n"
" float4 valueL=convert_float4(vload4(0,input+in_offset));\n"
" float4 valueR=convert_float4(vload4(0,input+in_offset+shape.z));\n"
" #ifdef DOUBLE_INPUTS\n"
" float4 valueConstL=convert_float4(vload4(h,input1));\n"
" float4 valueConstR=convert_float4(vload4(h,input1+shape.z));\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float4 out=(erf(valueR*(float4)0.7071067932881648)+(float4)1.0)*valueR*(float4)0.5;\n"
" out *= valueL;\n"
" vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n"
// Fallback: one element per work-item.
"#else\n"
" const int in_offset=bc*shape.z*2+h;\n"
" const int out_offset=bc*shape.z+h;\n"
" \n"
" float valueL=(float)input[in_offset];\n"
" float valueR=(float)input[in_offset+shape.z];\n"
" #ifdef DOUBLE_INPUTS\n"
" float valueConstL=input1[h];\n"
" float valueConstR=input1[shape.z+h];\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float out=(erf(valueR*0.7071067932881648f)+1.0f)*valueR*0.5f;\n"
" out *= valueL;\n"
" output[out_offset]=out;\n"
"#endif\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for the `select_buf` kernel (buffer backend): element-wise
// ternary select. Work-item `idx` (global dim 0) writes
//   output[idx] = select[idx] ? input0[idx] : input1[idx]
// The build options INSIZE1_EUQAL_1 / INSIZE2_EUQAL_1 turn input0 / input1 into
// a broadcast scalar: element 0 is read for every idx. The macro spelling is
// left as-is; presumably the host passes exactly this spelling, so confirm
// before renaming.
// NOTE(review): `idy` (global dim 1) is bounds-checked by DEAL_NON_UNIFORM_DIM2
// but never used for indexing. Presumably the host launches with dim1 == 1; confirm.
// The string text is compiled at runtime, so every byte inside the literals matters.
const char* select_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__kernel void select_buf(GLOBAL_SIZE_2_DIMS\n"
" __global const int* select,\n"
" __global const FLOAT* input0,\n"
" __global const FLOAT* input1,\n"
" __global FLOAT* output\n"
" ) {\n"
" const int idx=get_global_id(0);\n"
" const int idy=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(idx,idy);\n"
" if (select[idx]) {\n"
"#ifdef INSIZE1_EUQAL_1\n"
" output[idx]=input0[0];\n"
"#else\n"
" output[idx]=input0[idx];\n"
"#endif\n"
" } else {\n"
"#ifdef INSIZE2_EUQAL_1\n"
" output[idx]=input1[0];\n"
"#else\n"
" output[idx]=input1[idx];\n"
"#endif\n"
" }\n"
"}\n"
;
#endif
// OpenCL C source for the image-backend GridSample kernels `nearest` and
// `bilinear`. The text is compiled at runtime, so every byte inside the string
// literals matters; the C++ comments between literals are documentation only.
const char* grid_sample = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// Padding modes for out-of-range sample positions. The values presumably mirror
// the host-side enum passed as `paddingMode`; confirm before changing them.
"enum BorderMode {\n"
" BorderMode_ZEROS=0,\n"
" BorderMode_CLAMP=1,\n"
" BorderMode_REFLECTION=2,\n"
" BorderMode_MIN=BorderMode_ZEROS,\n"
" BorderMode_MAX=BorderMode_REFLECTION\n"
"};\n"
// getPosition: maps a normalized grid coordinate x to a pixel coordinate on an
// axis of `range` pixels:
//   alignCorners == 1: (1 + x) * (range - 1) / 2   (-1 -> 0,    1 -> range-1)
//   otherwise:         ((1 + x) * range - 1) / 2   (-1 -> -0.5, 1 -> range-0.5)
"float getPosition(float x,int range,int alignCorners){\n"
" float a=alignCorners == 1? 1.0f : 0.0f;\n"
" float b=alignCorners == 1? 0.0f : 1.0f;\n"
" return ((1+x)*(range-a)-b)/2.0f;\n"
"}\n"
// CLAMP: returns v limited to [min, max].
"static int CLAMP(int v,int min,int max) {\n"
" if ((v)<min) {\n"
" (v)=min;\n"
" } else if ((v)>max) {\n"
" (v)=max;\n"
" }\n"
" return v;\n"
"}\n"
// sample: reads the texel at (w_offset_base + w, h_offset_base + h) of `tmp`,
// i.e. pixel (h, w) of one height x width tile. An out-of-range (h, w) yields 0
// for BorderMode_ZEROS. Every other mode uses the clamped, nearest-edge pixel.
// NOTE(review): no reflection of grid coordinates is visible in this source,
// yet the inline comment assumes the values were already reflected into (-1,1).
// Unless that happens elsewhere, BorderMode_REFLECTION behaves like CLAMP
// here; confirm.
"FLOAT4 sample(int h,int w,\n"
" const int w_offset_base,\n"
" const int h_offset_base,\n"
" __read_only image2d_t tmp,\n"
" int height,int width,\n"
" enum BorderMode paddingMode){\n"
" if (h<0 || h >= height || w<0 || w >= width) {\n"
" if(paddingMode == BorderMode_ZEROS)\n"
" {\n"
" return 0.0f;\n"
" }\n"
" // Clearly,CLAMP is the right way to go for GridSamplePaddingMode_BORDER\n"
" // For GridSamplePaddingMode_REFLECTION,since we have reflected the values into (-1,1),\n"
" // the leftover reflections degrade to GridSamplePaddingMode_BORDER\n"
" h=CLAMP(h,0,height-1);\n"
" w=CLAMP(w,0,width-1);\n"
" }\n"
" return RI_F(tmp,SAMPLER,(int2)(w_offset_base+w,h_offset_base+h));\n"
"}\n"
// nearest: global dims are (channel block, output x, batch * output_height + y).
// It fetches this position's (x, y) grid pair from the packed grid image (layout
// sketched in the kernel comment), maps it with getPosition, and rounds with
// floor(p + 0.5). The selected input texel is written to
// (channel_block * output_width + x, batch * output_height + y).
"__kernel void nearest(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n"
" __read_only image2d_t grid,\n"
" __write_only image2d_t output,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int output_height,\n"
" __private const int output_width,\n"
" __private const enum BorderMode paddingMode,\n"
" __private const int alignCorners\n"
" ){\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_batch_idx=output_batch_height_block_idx/output_height;\n"
" const int output_height_idx=output_batch_height_block_idx % output_height;\n"
" // grid data format has been converted from nchw to nc4hw4\n"
" /* slice slice\n"
" (x1,y1)...(xn,y1) (x1,x1,x1,x1) (y1,y2,y3,y4) | (x1,x1,x1,x1) (y5,y6,y7,y8) | ... \n"
" . . . . | . . |\n"
" . . <-> . . | . . |\n"
" . . . . | . . |\n"
" (x1,ym)...(xn,ym) (xn,xn,xn,xn) (y1,y2,y3,y4) | (xn,xn,xn,xn) (y5,y6,y7,y8) | ...\n"
" */\n"
" const int slice=output_height_idx/4;\n"
" const int grid_w_offset=0;\n"
" const int grid_h_offset=mad24(output_batch_idx,output_width,output_width_block_idx);\n"
" \n"
" FLOAT4 grid_x=RI_F(grid,SAMPLER,(int2)(grid_w_offset+2*slice,grid_h_offset));\n"
" FLOAT4 grid_y=RI_F(grid,SAMPLER,(int2)(grid_w_offset+1+2*slice,grid_h_offset));\n"
" const float arr[8]={grid_x.x,grid_y.x,grid_x.y,grid_y.y,grid_x.z,grid_y.z,grid_x.w,grid_y.w};\n"
" \n"
" // get grid x,y\n"
" const int arr_offset=output_height_idx % 4;\n"
" const float x=arr[2*arr_offset];\n"
" const float y=arr[2*arr_offset+1];\n"
" // convert grid x,y to input coordinate range\n"
" float in_grid_x=getPosition(x,input_width,alignCorners);\n"
" float in_grid_y=getPosition(y,input_height,alignCorners);\n"
" // get nearest point\n"
" int nw=floor(in_grid_x+0.5f);\n"
" int nh=floor(in_grid_y+0.5f);\n"
" const int inp_w_offset=mul24(output_channel_block_idx,input_width);\n"
" const int inp_h_offset=mul24(output_batch_idx,input_height);\n"
" FLOAT4 value=sample(nh,nw,inp_w_offset,inp_h_offset,input,input_height,input_width,paddingMode);\n"
" const int output_w_offset=mad24(output_channel_block_idx,output_width,output_width_block_idx);\n"
" const int output_h_offset=mad24(output_batch_idx,output_height,output_height_idx);\n"
" WI_F(output,(int2)(output_w_offset,output_h_offset),value);\n"
"}\n"
// bilinear: same indexing and grid fetch as `nearest`. It blends the four
// neighbours at floor/ceil of the mapped position. x_weight = in_w1 - x weights
// column in_w0, and y_weight = in_h1 - y weights row in_h0. When a coordinate
// is integral, floor == ceil and the weight is 0, so the identical neighbour
// gets full weight.
"__kernel void bilinear(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n"
" __read_only image2d_t grid,\n"
" __write_only image2d_t output,\n"
" __private const int input_height,\n"
" __private const int input_width,\n"
" __private const int output_height,\n"
" __private const int output_width,\n"
" __private const enum BorderMode paddingMode,\n"
" __private const int alignCorners\n"
" ){\n"
" const int output_channel_block_idx=get_global_id(0);\n"
" const int output_width_block_idx=get_global_id(1);\n"
" const int output_batch_height_block_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n"
" const int output_batch_idx=output_batch_height_block_idx/output_height;\n"
" const int output_height_idx=output_batch_height_block_idx % output_height;\n"
" // get grid idx\n"
" const int slice=output_height_idx/4;\n"
" const int grid_w_offset=0;\n"
" const int grid_h_offset=mad24(output_batch_idx,output_width,output_width_block_idx);\n"
" \n"
" FLOAT4 grid_x=RI_F(grid,SAMPLER,(int2)(grid_w_offset+2*slice,grid_h_offset));\n"
" FLOAT4 grid_y=RI_F(grid,SAMPLER,(int2)(grid_w_offset+1+2*slice,grid_h_offset));\n"
" const float arr[8]={grid_x.x,grid_y.x,grid_x.y,grid_y.y,grid_x.z,grid_y.z,grid_x.w,grid_y.w};\n"
" \n"
" // get grid x,y\n"
" const int arr_offset=output_height_idx % 4;\n"
" const float x=arr[2*arr_offset];\n"
" const float y=arr[2*arr_offset+1];\n"
" // convert grid x,y to input coordinate range\n"
" float in_grid_x=getPosition(x,input_width,alignCorners);\n"
" float in_grid_y=getPosition(y,input_height,alignCorners);\n"
" int in_h0=floor(in_grid_y);\n"
" int in_w0=floor(in_grid_x);\n"
" int in_h1=ceil(in_grid_y);\n"
" int in_w1=ceil(in_grid_x);\n"
" float x_weight=in_w1-in_grid_x;\n"
" float y_weight=in_h1-in_grid_y;\n"
" const int inp_w_offset=mul24(output_channel_block_idx,input_width);\n"
" const int inp_h_offset=mul24(output_batch_idx,input_height);\n"
" FLOAT4 i00=sample(in_h0,in_w0,inp_w_offset,inp_h_offset,input,input_height,input_width,paddingMode);\n"
" FLOAT4 i01=sample(in_h0,in_w1,inp_w_offset,inp_h_offset,input,input_height,input_width,paddingMode);\n"
" FLOAT4 i10=sample(in_h1,in_w0,inp_w_offset,inp_h_offset,input,input_height,input_width,paddingMode);\n"
" FLOAT4 i11=sample(in_h1,in_w1,inp_w_offset,inp_h_offset,input,input_height,input_width,paddingMode);\n"
" // bilinear interpolation\n"
" FLOAT4 value=CONVERT_FLOAT4(((FLOAT4)x_weight*CONVERT_FLOAT4(i00)+(FLOAT4)(1.0f-x_weight)*CONVERT_FLOAT4(i01))*(FLOAT4)y_weight +\n"
" ((FLOAT4)x_weight*CONVERT_FLOAT4(i10)+(FLOAT4)(1.0f-x_weight)*CONVERT_FLOAT4(i11))*(FLOAT4)(1.0f- y_weight));\n"
" const int output_w_offset=mad24(output_channel_block_idx,output_width,output_width_block_idx);\n"
" const int output_h_offset=mad24(output_batch_idx,output_height,output_height_idx);\n"
" WI_F(output,(int2)(output_w_offset,output_h_offset),value);\n"
"}\n"
;
// OpenCL C source for quantized-weight repacking kernels (int8 and packed int4).
// Compiled at runtime: every byte inside the string literals is significant;
// the C++ comments between literals are documentation only.
//
// Packed-int4 convention (see the nibble selects below): source element i lives
// in byte i/2, in the high nibble when i is even and the low nibble when i is
// odd. Repacked bytes store the first element of each pair in the high nibble.
const char* buffer_convert_quant = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
// conv2d_filter_buffer_to_nc4hw4_buffer_int8: dim0 = input channel, dim1 =
// (output-channel block of 4) * height_width_size + spatial index. It gathers 4
// consecutive output channels for one (ic, kh, kw) into a char4. Channels
// >= output_channel stay 0. The result is stored at
// (ic * height_width_size * ceil(oc/4) + dim1) * 4.
"#ifdef USE_LOW_BIT_WEIGHT_INT8\n"
"// convert kernel : from int8 buffer(oihw) to int8 image(oc/4 h w ,ic oc4)\n"
"__kernel void conv2d_filter_buffer_to_nc4hw4_buffer_int8(GLOBAL_SIZE_2_DIMS\n"
" __global const char *input_ptr,\n"
" __private const int output_channel,\n"
" __private const int2 kernel_shape,\n"
" __private const int ic_h_w_size,\n"
" __private const int height_width_size,\n"
" __global char *output) {\n"
" int image_width_idx=get_global_id(0); // ic\n"
" int image_height_idx=get_global_id(1); // oc/4 h w\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int input_channel_4_idx=image_width_idx;\n"
" const int output_channel_4_idx=(image_height_idx/height_width_size)*4;\n"
" const int height_width_idx=image_height_idx % height_width_size;\n"
" const int buffer_height_idx=height_width_idx/kernel_shape.y;\n"
" const int buffer_width_idx=height_width_idx % kernel_shape.y;\n"
" const int buffer_offset=output_channel_4_idx*ic_h_w_size+input_channel_4_idx*height_width_size +\n"
" buffer_height_idx*kernel_shape.y+buffer_width_idx;\n"
" char4 output_values=0;\n"
" if (output_channel_4_idx<output_channel) {\n"
" const int remain_channel=output_channel-output_channel_4_idx;\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(char)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(char)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(char)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.w=(char)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(char)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(char)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(char)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(char)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(char)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(char)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" const int out_offset=(image_width_idx*height_width_size*((output_channel+3)/4)+image_height_idx)*4;\n"
" vstore4(output_values,0,output+out_offset);\n"
"}\n"
"#endif\n"
// conv2d_filter_buffer_to_nc4hw4_buffer_int4: packed-int4 counterpart of the
// kernel above. It emits two bytes per (ic, oc block, spatial index):
// (oc0 << 4 | oc1) and (oc2 << 4 | oc3). Each channel oc4_idx + k that is
// >= output_channel is read as a 0 byte instead of touching input_ptr. This
// mirrors the int8 kernel's remain_channel handling, so a partial last block
// (1-3 channels left) never reads past the source buffer.
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
"// convert kernel : from int8 buffer(oihw) to int4 image(oc/4 h w ,ic oc4)\n"
"__kernel void conv2d_filter_buffer_to_nc4hw4_buffer_int4(GLOBAL_SIZE_2_DIMS\n"
" __global const uchar *input_ptr,\n"
" __private const int output_channel,\n"
" __private const int2 kernel_shape,\n"
" __private const int ic_h_w_size,\n"
" __private const int height_width_size,\n"
" __global uchar *output) {\n"
" int image_width_idx=get_global_id(0); // ic\n"
" int image_height_idx=get_global_id(1); // oc/4 h w\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int input_channel_4_idx=image_width_idx;\n"
" const int output_channel_4_idx=(image_height_idx/height_width_size)*4;\n"
" const int height_width_idx=image_height_idx % height_width_size;\n"
" const int buffer_height_idx=height_width_idx/kernel_shape.y;\n"
" const int buffer_width_idx=height_width_idx % kernel_shape.y;\n"
" const int buffer_offset=output_channel_4_idx*ic_h_w_size+input_channel_4_idx*height_width_size+buffer_height_idx*kernel_shape.y+buffer_width_idx;\n"
" int index0=buffer_offset,index1=buffer_offset+ic_h_w_size,index2=buffer_offset+2*ic_h_w_size,index3=buffer_offset+3*ic_h_w_size;\n"
" uchar2 output_values_int4=(uchar2)(0,0);\n"
" uchar s0=input_ptr[index0/2];\n"
" uchar s1=output_channel_4_idx+1 >= output_channel ? 0 : input_ptr[index1/2];\n"
" uchar s2=output_channel_4_idx+2 >= output_channel ? 0 : input_ptr[index2/2];\n"
" uchar s3=output_channel_4_idx+3 >= output_channel ? 0 : input_ptr[index3/2];\n"
" output_values_int4.x=((index0 % 2) == 0 ? (s0 & 0xf0) : (s0 << 4)) | ((index1 % 2) == 0 ? (s1 >> 4) : (s1 & 0x0f));\n"
" output_values_int4.y=((index2 % 2) == 0 ? (s2 & 0xf0) : (s2 << 4)) | ((index3 % 2) == 0 ? (s3 >> 4) : (s3 & 0x0f));\n"
" const int out_offset=(image_width_idx*height_width_size*((output_channel+3)/4)+image_height_idx)*2;\n"
" vstore2(output_values_int4,0,output+out_offset);\n"
"}\n"
"#endif\n"
// CHAR16_TO_UCHAR8 / CHAR32_TO_UCHAR16: pack pairs of small signed values
// (each stored as value + 8, presumably int4 in [-8, 7]) into bytes, with the
// first of each pair in the high nibble. Not referenced by the kernels in this
// string.
"#define CHAR16_TO_UCHAR8(a, b) "" a=(uchar8)(((b.s0+8) << 4)+b.s1+8,((b.s2+8) << 4)+b.s3+8,((b.s4+8) << 4)+b.s5+8,((b.s6+8) << 4)+b.s7+8,((b.s8+8) << 4)+b.s9+8,((b.sa+8) << 4)+b.sb+8,((b.sc+8) << 4)+b.sd+8,((b.se+8) << 4)+b.sf+8);\n"
"#define CHAR32_TO_UCHAR16(a, b, c) "" a = (uchar16)(((b.s0 + 8) << 4) + b.s1 + 8, ((b.s2 + 8) << 4) + b.s3 + 8, ((b.s4 + 8) << 4) + b.s5 + 8, ((b.s6 + 8) << 4) + b.s7 + 8, ((b.s8 + 8) << 4) + b.s9 + 8, ((b.sa + 8) << 4) + b.sb + 8, ((b.sc + 8) << 4) + b.sd + 8, ((b.se + 8) << 4) + b.sf + 8, "" ((c.s0+8) << 4)+c.s1+8,((c.s2+8) << 4)+c.s3+8,((c.s4+8) << 4)+c.s5+8,((c.s6+8) << 4)+c.s7+8,((c.s8+8) << 4)+c.s9+8,((c.sa+8) << 4)+c.sb+8,((c.sc+8) << 4)+c.sd+8,((c.se+8) << 4)+c.sf+8);\n"
// conv2d_1x1_weight_quant_buffer: dim0 = block of 16 input channels (x),
// dim1 = output channel (y). The int8 path copies the 16 weights of row y
// starting at xin to (x * ceil(oc/4) * 4 + y) * 16. The int4 path writes 8
// packed bytes to x * ceil(oc/4) * 4 * 8 + y * 8. Under CHANNEL_LEAVE it
// repacks nibble by nibble, because y * input_channel + xin may be odd (not
// byte aligned).
// NOTE(review): the CHANNEL_LEAVE path does not check xin + 2i + 1 < input_channel.
// For the tail block it reads the next row's nibbles, and possibly past the buffer
// on the last row; confirm the host pads the weight buffer.
"__kernel void conv2d_1x1_weight_quant_buffer(GLOBAL_SIZE_2_DIMS\n"
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
" __global const uchar *input_ptr,\n"
"#else\n"
" __global const char *input_ptr,\n"
"#endif\n"
" __global char *output_ptr,\n"
" __private const int input_channel,\n"
" __private const int output_channel) {\n"
" int x=get_global_id(0); // ic/16\n"
" int y=get_global_id(1); // oc\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int xin=x << 4;\n"
" const int outputChannelC4=(output_channel+3) >> 2;\n"
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
" const int outputOffset=((x*outputChannelC4*4*8+y*8));\n"
"#ifdef CHANNEL_LEAVE\n"
" for(int i=0; i<8; ++i){\n"
" int index0=y*input_channel+xin+i*2;\n"
" int index1=y*input_channel+xin+i*2+1;\n"
" uchar s0=input_ptr[index0/2];\n"
" uchar s1=input_ptr[index1/2];\n"
" output_ptr[outputOffset+i]=((index0 % 2) == 0 ? (s0 & 0xf0) : (s0 << 4)) | ((index1 % 2) == 0 ? (s1 >> 4) : (s1 & 0x0f));\n"
" }\n"
"#else\n"
" const int inputOffset=(y*input_channel+xin)/2;\n"
" vstore8(convert_char8(vload8(0,input_ptr+inputOffset)),0,output_ptr+outputOffset);\n"
"#endif\n"
"#else\n"
" const int inputOffset=y*input_channel+xin;\n"
" const int outputOffset=(x*outputChannelC4*4+y) << 4;\n"
" vstore16(convert_char16(vload16(0,input_ptr+inputOffset)),0,output_ptr+outputOffset);\n"
"#endif\n"
"}\n"
// conv2d_1x1_weight_quant_image: same source indexing as the buffer variant,
// but it writes 16 bytes as one int4 texel at image coordinate (y, x). In the
// int4 path x indexes blocks of 32 input channels; in the int8 path, blocks of 16.
"__kernel void conv2d_1x1_weight_quant_image(GLOBAL_SIZE_2_DIMS\n"
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
" __global const uchar *input_ptr,\n"
"#else\n"
" __global const uchar *input_ptr,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int input_channel,\n"
" __private const int output_channel) {\n"
" int x=get_global_id(0); // ic/32\n"
" int y=get_global_id(1); // oc\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
" const int xin=x << 5;\n"
"#ifdef CHANNEL_LEAVE\n"
" uchar16 out=0;\n"
" uchar *out_ptr=(uchar*)&out;\n"
" for(int i=0; i<16; ++i){\n"
" int index0=y*input_channel+xin+i*2;\n"
" int index1=y*input_channel+xin+i*2+1;\n"
" uchar s0=input_ptr[index0/2];\n"
" uchar s1=input_ptr[index1/2];\n"
" out_ptr[i]=((index0 % 2) == 0 ? (s0 & 0xf0) : (s0 << 4)) | ((index1 % 2) == 0 ? (s1 >> 4) : (s1 & 0x0f));\n"
" }\n"
" write_imagei(output,(int2)(y,x),as_int4(out));\n"
"#else\n"
" const int inputOffset=(y*input_channel+xin)/2;\n"
" write_imagei(output,(int2)(y,x),as_int4(vload16(0,input_ptr+inputOffset)));\n"
"#endif\n"
"#else\n"
" const int xin=x << 4;\n"
" const int inputOffset=y*input_channel+xin;\n"
" write_imagei(output,(int2)(y,x),as_int4(vload16(0,input_ptr+inputOffset)));\n"
"#endif\n"
"}\n"
// conv2d_1x1_ic_oc_weight_quant_buffer: repacks OI weights into the layout
// (Ci/icPack, Co/ocPack, icPack, ocPack). dim0 = input-channel pack, dim1 =
// output-channel pack. In the int4 path, consecutive output channels are paired
// into one byte (first in the high nibble).
// NOTE(review): no bounds checks against input_channel / output_channel for a
// partial last pack; confirm the host only launches full packs or pads the source.
"__kernel void conv2d_1x1_ic_oc_weight_quant_buffer(GLOBAL_SIZE_2_DIMS\n"
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
" __global const uchar *input_ptr,\n"
" __global uchar *output_ptr,//(Ci/packCin， Co/packCout,packCin， packCout)\n"
"#else\n"
" __global const char *input_ptr,\n"
" __global char *output_ptr,//(Ci/packCin， Co/packCout,packCin， packCout)\n"
"#endif\n"
" __private const int input_channel,\n"
" __private const int output_channel,\n"
" __private const int icPack,\n"
" __private const int ocPack) {\n"
" int x=get_global_id(0); // ic/icPack\n"
" int y=get_global_id(1); // oc/ocPack\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int xin=x*icPack;\n"
" const int yin=y*ocPack;\n"
" const int inputChannelC4=(input_channel+icPack-1)/icPack;\n"
" const int outputChannelC4=(output_channel+ocPack-1)/ocPack;\n"
"#ifdef USE_LOW_BIT_WEIGHT_INT4\n"
" const int inputOffset=(yin*input_channel+xin)/2;\n"
" const int outputOffset=((x*outputChannelC4+y)*icPack*ocPack)/2;\n"
" for(int i=0; i<icPack; ++i){\n"
" for(int j=0; j<ocPack/2; ++j){\n"
" int index0=(yin+j*2)*input_channel+xin+i;\n"
" int index1=(yin+j*2+1)*input_channel+xin+i;\n"
" uchar s0=input_ptr[index0/2];\n"
" uchar s1=input_ptr[index1/2];\n"
" s0=(index0 % 2) == 0 ? (s0 & 0xf0) : ((s0 & 0x0f) << 4);\n"
" s1=(index1 % 2) == 0 ? (s1 >> 4) : (s1 & 0x0f);\n"
" output_ptr[outputOffset+i*(ocPack/2)+j]=s0 | s1;\n"
" }\n"
" }\n"
"#else\n"
" const int inputOffset=yin*input_channel+xin;\n"
" const int outputOffset=(x*outputChannelC4+y)*icPack*ocPack;\n"
" for(int i=0; i<icPack; ++i){\n"
" for(int j=0; j<ocPack; ++j){\n"
" output_ptr[outputOffset+i*ocPack+j]=input_ptr[inputOffset+j*input_channel+i];\n"
" }\n"
" }\n"
"#endif\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source (compiled at runtime) for the buffer-backend GEMM helpers.
//  - transpose_pad : repacks a [K/4, M, 4] buffer into a zero-padded row-major
//                    [alignK, alignM] matrix; each work-item writes a 4x4 tile
//                    (global dims: alignM/4 x alignK/4, presumably — confirm on host).
//  - transpose_bias: adds the per-column vector input1 (indexed by N) to the
//                    [alignM, alignN] matrix input0, optionally applies RELU /
//                    RELU6, and stores into [N/4, M, 4]; each work-item handles
//                    M_VEC rows (default 1) and 4 columns.
// transpose_bias stops at row M so that an M that is not a multiple of M_VEC can
// never write into the next N/4 plane of the output (which another work-item owns).
// NOTE(review): FLOAT / FLOAT4 come from the build options, not from this string.
const char* gemm_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"// [K/4,M,4] -> [alignK,alignM]\n"
"__kernel void transpose_pad(GLOBAL_SIZE_DIM2\n"
" const int alignM,\n"
" const int alignK,\n"
" const int M,\n"
" const int K,\n"
" const int area,\n"
" __global const FLOAT* input,\n"
" __global FLOAT* output\n"
" ) {\n"
" const int idx_m4=get_global_id(0); // idx M\n"
" const int idx_k4=get_global_id(1); // idx K\n"
" UNIFORM_BOUNDRY_CHECK(idx_m4,idx_k4);\n"
" const int idx_m=idx_m4 << 2;\n"
" const int idx_k=idx_k4 << 2;\n"
" const int K_4=(K+3) >> 2;\n"
" const int in_offset_base=(idx_k4*M+idx_m)*4;\n"
" const int out_offset_base=idx_k*alignM+idx_m;\n"
" \n"
" FLOAT4 m0k4=(idx_k4 >= K_4 || idx_m+0 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base);\n"
" FLOAT4 m1k4=(idx_k4 >= K_4 || idx_m+1 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+4);\n"
" FLOAT4 m2k4=(idx_k4 >= K_4 || idx_m+2 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+8);\n"
" FLOAT4 m3k4=(idx_k4 >= K_4 || idx_m+3 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+12);\n"
" \n"
" vstore4((FLOAT4)(m0k4.x,m1k4.x,m2k4.x,m3k4.x),0,output+out_offset_base);\n"
" vstore4((FLOAT4)(m0k4.y,m1k4.y,m2k4.y,m3k4.y),0,output+out_offset_base+alignM);\n"
" vstore4((FLOAT4)(m0k4.z,m1k4.z,m2k4.z,m3k4.z),0,output+out_offset_base+alignM+alignM);\n"
" vstore4((FLOAT4)(m0k4.w,m1k4.w,m2k4.w,m3k4.w),0,output+out_offset_base+alignM+alignM+alignM);\n"
"}\n"
"#ifndef M_VEC\n"
"#define M_VEC 1\n"
"#endif\n"
"// [alignM,alignN] -> [N/4,B,area,N4] (M=B*area)\n"
"__kernel void transpose_bias(GLOBAL_SIZE_DIM2\n"
" const int alignM,\n"
" const int alignN,\n"
" const int M,\n"
" const int N,\n"
" const int area,\n"
" __global const FLOAT* input0,\n"
" __global const FLOAT* input1,\n"
" __global FLOAT* output\n"
" ) {\n"
" int idx_m=get_global_id(0); // idx M\n"
" int idx_n4=get_global_id(1); // idx N\n"
" UNIFORM_BOUNDRY_CHECK(idx_m,idx_n4);\n"
" const int idx_n=idx_n4 << 2;\n"
" idx_m=idx_m*M_VEC;\n"
" FLOAT4 res1=vload4(0,input1+idx_n);\n"
" #pragma unroll\n"
" for(int i=0; i<M_VEC; i++) {\n"
" // rows >= M would be stored into the next N/4 plane of the output\n"
" if(idx_m+i >= M) break;\n"
" FLOAT4 res0=vload4(0,input0+(idx_m+i)*alignN+idx_n);\n"
" FLOAT4 res=res0+res1;\n"
" #ifdef RELU\n"
" res=fmax(res,(FLOAT4)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n"
" #endif\n"
" vstore4(res,0,output+((idx_n4*M+idx_m+i) << 2));\n"
" }\n"
"}\n"
;
#endif
// OpenCL C source (compiled at runtime) for copy_buffer_to_image2d.
// Copies a width x height grid of 4-component texels from a linear buffer
// (texel (x, y) read from input[x + y * width]) into a 2D image with WI_F.
// With BUFFER_INP_FP32 defined the buffer is typed float4 and each component
// is cast to FLOAT; otherwise it is typed FLOAT4. Work-items outside the
// width x height range write nothing.
// NOTE(review): FLOAT / FLOAT4 / WI_F are not defined in this string; presumably
// they come from the build options — confirm against the program builder.
const char* copy_buffer_to_image2d = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void copy_buffer_to_image2d(\n"
" #ifdef BUFFER_INP_FP32\n"
" __global const float4* input,\n"
" #else\n"
" __global const FLOAT4* input,\n"
" #endif\n"
" __write_only image2d_t uOutput,\n"
" __private const int width,__private const int height) {\n"
" int x=get_global_id(0);\n"
" int y=get_global_id(1);\n"
" if (x<width && y<height) {\n"
" WI_F(uOutput,(int2)(x,y),(FLOAT4)((FLOAT)input[x+y*width].x,(FLOAT)input[x+y*width].y,(FLOAT)input[x+y*width].z,(FLOAT)input[x+y*width].w));\n"
" }\n"
"}\n"
;
// OpenCL C source (compiled at runtime) for the image-backend loop kernels.
//  - batch_matmul: per-z product O = A * B (+ C when BIAS is defined). Each
//    work-item computes a 4 (rows, along e) x 4 (columns, along h) tile and
//    accumulates over l. Each operand's base offset is index * steps + offsets,
//    where index is offset_X[z] when the matching iters component is >= 0 and
//    z otherwise. TRANSPOSE_A: A is [l, e] instead of [e, l]; TRANSPOSE_B: B is
//    [h, l] instead of [l, h]; O is [e, h]. H_LEAVES is the number of valid
//    columns in the last 4-wide column tile (0 means every tile is full).
//    Every TRANSPOSE_A/TRANSPOSE_B test uses `#if`, so a -DTRANSPOSE_A=0 build
//    selects the untransposed load AND the untransposed accumulation.
//    The remainder loop start is clamped at 0 so l == 0 never indexes A/B at
//    negative offsets.
//  - tile: image -> NCHW buffer (NHWC when MNN_NHWC); channels >= `channel`
//    are not written.
//  - pack: NCHW (NHWC when MNN_NHWC) buffer -> image; missing channels are 0.
//  - batch_gather: element copy driven by per-z offset tables; a source offset
//    outside [0, inputSize) writes 0.
//  - broadcast_binary (only with LOOP_BINARY_OPERATOR): broadcasting binary op
//    between two images, evaluated in float.
// NOTE(review): FLOAT*, RI_DATA / WI_DATA, INPUT_TYPE / OUTPUT_TYPE* and the
// CONVERT_* macros are not defined in this string; presumably build options.
const char* loop = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define PI 3.141592653589f\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void batch_matmul(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global FLOAT* output,__global FLOAT* input_A,__global FLOAT* input_B,\n"
"#ifdef BIAS\n"
" __global FLOAT* input_C,\n"
"#endif\n"
" __global int* offset_O,__global int* offset_A,__global int* offset_B,\n"
"#ifdef BIAS\n"
" __global int* offset_C,\n"
"#endif\n"
" __private const int e,\n"
" __private const int l,\n"
" __private const int h,\n"
" __private const int4 offsets,\n"
" __private const int4 iters,\n"
" __private const int4 steps) {\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" pos.x <<= 2;\n"
" pos.y <<= 2;\n"
" int4 index=(int4)(pos.z);\n"
" if (iters.x >= 0) {\n"
" index.x=offset_O[pos.z];\n"
" }\n"
" if (iters.y >= 0) {\n"
" index.y=offset_A[pos.z];\n"
" }\n"
" if (iters.z >= 0) {\n"
" index.z=offset_B[pos.z];\n"
" }\n"
"#ifdef BIAS\n"
" if (iters.w >= 0) {\n"
" index.w=offset_C[pos.z];\n"
" }\n"
"#endif\n"
" int4 offset=index*steps+offsets;\n"
" \n"
"#if TRANSPOSE_A\n"
" __global FLOAT* A_ptr=input_A+offset.y+pos.y;\n"
"#else\n"
" __global FLOAT* A_ptr=input_A+offset.y+pos.y*l;\n"
"#endif\n"
"#if TRANSPOSE_B\n"
" __global FLOAT* B_ptr=input_B+offset.z+pos.x*l;\n"
"#else\n"
" __global FLOAT* B_ptr=input_B+offset.z+pos.x;\n"
"#endif\n"
"#ifdef BIAS\n"
" FLOAT4 value0=vload4(0,input_C+offset.w+pos.x);\n"
" FLOAT4 value1=value0;\n"
" FLOAT4 value2=value0;\n"
" FLOAT4 value3=value0;\n"
"#else\n"
" FLOAT4 value0=(FLOAT4)0;\n"
" FLOAT4 value1=(FLOAT4)0;\n"
" FLOAT4 value2=(FLOAT4)0;\n"
" FLOAT4 value3=(FLOAT4)0;\n"
"#endif\n"
" const int l_pack=(l+3) >> 2;\n"
" for(int i=0; i<l_pack-1; ++i){\n"
" int l_offset=i << 2;\n"
" FLOAT4 value_a0,value_a1,value_a2,value_a3,value_b0,value_b1,value_b2,value_b3;\n"
"#if TRANSPOSE_A\n"
" value_a0=vload4(0,A_ptr+l_offset*e);\n"
" value_a1=vload4(0,A_ptr+(l_offset+1)*e);\n"
" value_a2=vload4(0,A_ptr+(l_offset+2)*e);\n"
" value_a3=vload4(0,A_ptr+(l_offset+3)*e);\n"
"#else\n"
" value_a0=vload4(0,A_ptr+l_offset);\n"
" value_a1=vload4(0,A_ptr+l_offset+l);\n"
" value_a2=vload4(0,A_ptr+l_offset+2*l);\n"
" value_a3=vload4(0,A_ptr+l_offset+3*l);\n"
"#endif\n"
"#if TRANSPOSE_B\n"
" FLOAT4 value_tmp0=vload4(0,B_ptr+l_offset);\n"
" FLOAT4 value_tmp1=vload4(0,B_ptr+l_offset+l);\n"
" FLOAT4 value_tmp2=vload4(0,B_ptr+l_offset+2*l);\n"
" FLOAT4 value_tmp3=vload4(0,B_ptr+l_offset+3*l);\n"
" value_b0=(FLOAT4)(value_tmp0.x,value_tmp1.x,value_tmp2.x,value_tmp3.x);\n"
" value_b1=(FLOAT4)(value_tmp0.y,value_tmp1.y,value_tmp2.y,value_tmp3.y);\n"
" value_b2=(FLOAT4)(value_tmp0.z,value_tmp1.z,value_tmp2.z,value_tmp3.z);\n"
" value_b3=(FLOAT4)(value_tmp0.w,value_tmp1.w,value_tmp2.w,value_tmp3.w);\n"
"#else\n"
" value_b0=vload4(0,B_ptr+l_offset*h);\n"
" value_b1=vload4(0,B_ptr+(l_offset+1)*h);\n"
" value_b2=vload4(0,B_ptr+(l_offset+2)*h);\n"
" value_b3=vload4(0,B_ptr+(l_offset+3)*h);\n"
"#endif\n"
"#if TRANSPOSE_A\n"
" value0=mad((FLOAT4)value_a0.x,value_b0,value0);\n"
" value0=mad((FLOAT4)value_a1.x,value_b1,value0);\n"
" value0=mad((FLOAT4)value_a2.x,value_b2,value0);\n"
" value0=mad((FLOAT4)value_a3.x,value_b3,value0);\n"
" \n"
" value1=mad((FLOAT4)value_a0.y,value_b0,value1);\n"
" value1=mad((FLOAT4)value_a1.y,value_b1,value1);\n"
" value1=mad((FLOAT4)value_a2.y,value_b2,value1);\n"
" value1=mad((FLOAT4)value_a3.y,value_b3,value1);\n"
" \n"
" value2=mad((FLOAT4)value_a0.z,value_b0,value2);\n"
" value2=mad((FLOAT4)value_a1.z,value_b1,value2);\n"
" value2=mad((FLOAT4)value_a2.z,value_b2,value2);\n"
" value2=mad((FLOAT4)value_a3.z,value_b3,value2);\n"
" \n"
" value3=mad((FLOAT4)value_a0.w,value_b0,value3);\n"
" value3=mad((FLOAT4)value_a1.w,value_b1,value3);\n"
" value3=mad((FLOAT4)value_a2.w,value_b2,value3);\n"
" value3=mad((FLOAT4)value_a3.w,value_b3,value3);\n"
"#else\n"
" value0=mad((FLOAT4)value_a0.x,value_b0,value0);\n"
" value0=mad((FLOAT4)value_a0.y,value_b1,value0);\n"
" value0=mad((FLOAT4)value_a0.z,value_b2,value0);\n"
" value0=mad((FLOAT4)value_a0.w,value_b3,value0);\n"
" \n"
" value1=mad((FLOAT4)value_a1.x,value_b0,value1);\n"
" value1=mad((FLOAT4)value_a1.y,value_b1,value1);\n"
" value1=mad((FLOAT4)value_a1.z,value_b2,value1);\n"
" value1=mad((FLOAT4)value_a1.w,value_b3,value1);\n"
" \n"
" value2=mad((FLOAT4)value_a2.x,value_b0,value2);\n"
" value2=mad((FLOAT4)value_a2.y,value_b1,value2);\n"
" value2=mad((FLOAT4)value_a2.z,value_b2,value2);\n"
" value2=mad((FLOAT4)value_a2.w,value_b3,value2);\n"
" \n"
" value3=mad((FLOAT4)value_a3.x,value_b0,value3);\n"
" value3=mad((FLOAT4)value_a3.y,value_b1,value3);\n"
" value3=mad((FLOAT4)value_a3.z,value_b2,value3);\n"
" value3=mad((FLOAT4)value_a3.w,value_b3,value3);\n"
"#endif\n"
" }\n"
" // remainder (last 1..4 values of l); clamped so l == 0 runs zero iterations\n"
" for(int i=max((l_pack-1) << 2,0); i<l; ++i){\n"
"#if TRANSPOSE_A\n"
" FLOAT4 value_a=vload4(0,A_ptr+i*e);\n"
"#else\n"
" FLOAT4 value_a;\n"
" value_a.x=A_ptr[i];\n"
" value_a.y=A_ptr[i+l];\n"
" value_a.z=A_ptr[i+2*l];\n"
" value_a.w=A_ptr[i+3*l];\n"
"#endif\n"
"#if TRANSPOSE_B\n"
" FLOAT4 value_b;\n"
" value_b.x=B_ptr[i];\n"
" value_b.y=B_ptr[i+l];\n"
" value_b.z=B_ptr[i+2*l];\n"
" value_b.w=B_ptr[i+3*l];\n"
"#else\n"
" FLOAT4 value_b=vload4(0,B_ptr+i*h);\n"
"#endif\n"
" value0=mad((FLOAT4)value_a.x,value_b,value0);\n"
" value1=mad((FLOAT4)value_a.y,value_b,value1);\n"
" value2=mad((FLOAT4)value_a.z,value_b,value2);\n"
" value3=mad((FLOAT4)value_a.w,value_b,value3);\n"
" }\n"
" \n"
" const int output_offset=offset.x+pos.y*h+pos.x;\n"
"#if H_LEAVES == 0\n"
" vstore4(value0,0,output+output_offset);\n"
" if(pos.y+1 >= e) return;\n"
" vstore4(value1,0,output+output_offset+h);\n"
" if(pos.y+2 >= e) return;\n"
" vstore4(value2,0,output+output_offset+2*h);\n"
" if(pos.y+3 >= e) return;\n"
" vstore4(value3,0,output+output_offset+3*h);\n"
"#else\n"
" if(pos.x+3<h){\n"
" vstore4(value0,0,output+output_offset);\n"
" if(pos.y+1 >= e) return;\n"
" vstore4(value1,0,output+output_offset+h);\n"
" if(pos.y+2 >= e) return;\n"
" vstore4(value2,0,output+output_offset+2*h);\n"
" if(pos.y+3 >= e) return;\n"
" vstore4(value3,0,output+output_offset+3*h);\n"
" }else{\n"
"#if H_LEAVES == 1\n"
" output[output_offset]=value0.x;\n"
" if(pos.y+1 >= e) return;\n"
" output[output_offset+h]=value1.x;\n"
" if(pos.y+2 >= e) return;\n"
" output[output_offset+2*h]=value2.x;\n"
" if(pos.y+3 >= e) return;\n"
" output[output_offset+3*h]=value3.x;\n"
"#elif H_LEAVES == 2\n"
" vstore2((FLOAT2)value0.xy,0,output+output_offset);\n"
" if(pos.y+1 >= e) return;\n"
" vstore2((FLOAT2)value1.xy,0,output+output_offset+h);\n"
" if(pos.y+2 >= e) return;\n"
" vstore2((FLOAT2)value2.xy,0,output+output_offset+2*h);\n"
" if(pos.y+3 >= e) return;\n"
" vstore2((FLOAT2)value3.xy,0,output+output_offset+3*h);\n"
"#elif H_LEAVES == 3\n"
" vstore3((FLOAT3)value0.xyz,0,output+output_offset);\n"
" if(pos.y+1 >= e) return;\n"
" vstore3((FLOAT3)value1.xyz,0,output+output_offset+h);\n"
" if(pos.y+2 >= e) return;\n"
" vstore3((FLOAT3)value2.xyz,0,output+output_offset+2*h);\n"
" if(pos.y+3 >= e) return;\n"
" vstore3((FLOAT3)value3.xyz,0,output+output_offset+3*h);\n"
"#endif\n"
" }\n"
"#endif\n"
" }\n"
"}\n"
"__kernel void tile(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __read_only image2d_t input,\n"
" __global OUTPUT_TYPE* output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int w=pos.x % width;\n"
" const int h=pos.x/width;\n"
" const int c=pos.y << 2;\n"
"#ifdef MNN_NHWC\n"
" const int c_dst_pitch=1;\n"
" const int x_dst_pitch=c_dst_pitch*channel;\n"
" const int y_dst_pitch=x_dst_pitch*width;\n"
" const int b_dst_pitch=y_dst_pitch*height;\n"
"#else\n"
" const int x_dst_pitch=1;\n"
" const int y_dst_pitch=x_dst_pitch*width;\n"
" const int c_dst_pitch=y_dst_pitch*height;\n"
" const int b_dst_pitch=c_dst_pitch*channel;\n"
"#endif\n"
" __global OUTPUT_TYPE* dst_ptr=output+pos.z*b_dst_pitch+c*c_dst_pitch+h*y_dst_pitch+w*x_dst_pitch;\n"
" \n"
" OUTPUT_TYPE4 value=CONVERT_OUTPUT4(RI_DATA(input,SAMPLER,(int2)(pos.y*width+w,pos.z*height+h)));\n"
" dst_ptr[0]=value.x;\n"
" if(c+1 >= channel)return;\n"
" dst_ptr[c_dst_pitch]=value.y;\n"
" if(c+2 >= channel)return;\n"
" dst_ptr[2*c_dst_pitch]=value.z;\n"
" if(c+3 >= channel)return;\n"
" dst_ptr[3*c_dst_pitch]=value.w;\n"
" }\n"
"}\n"
"__kernel void pack(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input,\n"
" __write_only image2d_t output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int w=pos.x % width;\n"
" const int h=pos.x/width;\n"
" const int c=pos.y << 2;\n"
"#ifdef MNN_NHWC\n"
" const int c_src_pitch=1;\n"
" const int x_src_pitch=c_src_pitch*channel;\n"
" const int y_src_pitch=x_src_pitch*width;\n"
" const int b_src_pitch=y_src_pitch*height;\n"
"#else\n"
" const int x_src_pitch=1;\n"
" const int y_src_pitch=x_src_pitch*width;\n"
" const int c_src_pitch=y_src_pitch*height;\n"
" const int b_src_pitch=c_src_pitch*channel;\n"
"#endif\n"
" __global INPUT_TYPE* src_ptr=input+pos.z*b_src_pitch+c*c_src_pitch+h*y_src_pitch+w*x_src_pitch;\n"
" OUTPUT_TYPE_I4 value=(OUTPUT_TYPE_I4)0;\n"
" OUTPUT_TYPE_I *value_ptr=(OUTPUT_TYPE_I*)&value;\n"
" for(int i=0; i<4 && (i+c<channel); ++i){\n"
" value_ptr[i]=(OUTPUT_TYPE_I)src_ptr[i*c_src_pitch];\n"
" }\n"
" WI_DATA(output,(int2)(pos.y*width+w,pos.z*height+h),value);\n"
" }\n"
"}\n"
"__kernel void batch_gather(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input,\n"
" __global int* offset_dst,__global int* offset_src,\n"
" __private const int x_size,\n"
" __private const int4 stride_src,\n"
" __private const int4 stride_dst,\n"
" __private const int2 steps,\n"
" __private const int2 iters,\n"
" __private const int inputSize) {\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" \n"
" int x=pos.x % x_size;\n"
" int y=pos.x/x_size;\n"
" int2 index=(int2)(pos.z,pos.z);\n"
" if (iters.x >= 0) {\n"
" index.x=offset_dst[pos.z];\n"
" }\n"
" if (iters.y >= 0) {\n"
" index.y=offset_src[pos.z];\n"
" }\n"
" int2 offset=index*steps;\n"
" if(offset.x >= 0){\n"
" if(offset.y >= 0 && offset.y<inputSize){\n"
" output[offset.x+stride_dst.w+x*stride_dst.x+y*stride_dst.y+pos.y*stride_dst.z]=(OUTPUT_TYPE)input[offset.y+stride_src.w+x*stride_src.x+y*stride_src.y+pos.y*stride_src.z];\n"
" }else{\n"
" output[offset.x+stride_dst.w+x*stride_dst.x+y*stride_dst.y+pos.y*stride_dst.z]=(OUTPUT_TYPE)(0);\n"
" }\n"
" }\n"
" }\n"
"}\n"
"#ifdef LOOP_BINARY_OPERATOR\n"
"__kernel void broadcast_binary(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __write_only image2d_t output,__read_only image2d_t input0,__read_only image2d_t input1,\n"
" __private const int8 src0_size,//(batch,channel,height,width)\n"
" __private const int4 src0C4_size,// nc4hw4\n"
" __private const int8 src1_size,\n"
" __private const int4 src1C4_size,\n"
" __private const int8 dst_size,\n"
" __private const int dst_width,\n"
" __private const int dst_height,\n"
" __private const int dst_channel,\n"
" __private const int channel_block) {\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" \n"
" const int wo=pos.x;\n"
" const int ho=pos.y;\n"
" const int co=pos.z % channel_block;\n"
" const int no=pos.z/channel_block;\n"
" int co4=co << 2;\n"
" int4 covec=(int4)(co4 % dst_channel,(co4+1) % dst_channel,(co4+2) % dst_channel,(co4+3) % dst_channel);\n"
" int4 out_offset=((no*dst_channel+covec)*dst_height+ho)*dst_width+wo;\n"
" int4 w=out_offset % (dst_size.s3*dst_size.s4); out_offset /= (dst_size.s3*dst_size.s4);\n"
" int4 h=out_offset % dst_size.s2; out_offset /= dst_size.s2;\n"
" int4 c=out_offset % dst_size.s1; out_offset /= dst_size.s1;\n"
" int4 n=out_offset % dst_size.s0;\n"
" float4 in0,in1;\n"
" float* in0_ptr=(float*)&in0;\n"
" float* in1_ptr=(float*)&in1;\n"
" \n"
" {\n"
" int4 w0=w % (src0_size.s3*src0_size.s4);\n"
" int4 h0=h % src0_size.s2;\n"
" int4 c0=c % src0_size.s1;\n"
" int4 n0=n % src0_size.s0;\n"
" int* w0_ptr=(int*)&w0;\n"
" int* h0_ptr=(int*)&h0;\n"
" int* c0_ptr=(int*)&c0;\n"
" int* n0_ptr=(int*)&n0;\n"
" for(int i=0; i<4; ++i){\n"
" int c4offset=((n0_ptr[i]*src0_size.s1+c0_ptr[i])*src0_size.s2+h0_ptr[i])*src0_size.s3*src0_size.s4+w0_ptr[i];\n"
" int wc4=c4offset % src0C4_size.x; c4offset /= src0C4_size.x;\n"
" int hc4=c4offset % src0C4_size.y; c4offset /= src0C4_size.y;\n"
" int cc4=c4offset % src0C4_size.z; c4offset /= src0C4_size.z;\n"
" int nc4=c4offset % src0C4_size.w;\n"
" int cc4_offset=cc4/4;\n"
" int cc4_remain=cc4 % 4;\n"
" float4 tmp=convert_float4(RI_DATA(input0,SAMPLER,(int2)(cc4_offset*src0C4_size.x+wc4,nc4*src0C4_size.y+hc4)));\n"
" float *tmp_ptr=(float*)&tmp;\n"
" in0_ptr[i]=tmp_ptr[cc4_remain];\n"
" }\n"
" }\n"
" \n"
" {\n"
" int4 w0=w % (src1_size.s3*src1_size.s4);\n"
" int4 h0=h % src1_size.s2;\n"
" int4 c0=c % src1_size.s1;\n"
" int4 n0=n % src1_size.s0;\n"
" int* w0_ptr=(int*)&w0;\n"
" int* h0_ptr=(int*)&h0;\n"
" int* c0_ptr=(int*)&c0;\n"
" int* n0_ptr=(int*)&n0;\n"
" for(int i=0; i<4; ++i){\n"
" int c4offset=((n0_ptr[i]*src1_size.s1+c0_ptr[i])*src1_size.s2+h0_ptr[i])*src1_size.s3*src1_size.s4+w0_ptr[i];\n"
" int wc4=c4offset % src1C4_size.x; c4offset /= src1C4_size.x;\n"
" int hc4=c4offset % src1C4_size.y; c4offset /= src1C4_size.y;\n"
" int cc4=c4offset % src1C4_size.z; c4offset /= src1C4_size.z;\n"
" int nc4=c4offset % src1C4_size.w;\n"
" int cc4_offset=cc4/4;\n"
" int cc4_remain=cc4 % 4;\n"
" float4 tmp=convert_float4(RI_DATA(input1,SAMPLER,(int2)(cc4_offset*src1C4_size.x+wc4,nc4*src1C4_size.y+hc4)));\n"
" float *tmp_ptr=(float*)&tmp;\n"
" in1_ptr[i]=tmp_ptr[cc4_remain];\n"
" }\n"
" }\n"
" \n"
" float4 out=LOOP_BINARY_OPERATOR;\n"
" WI_DATA(output,(int2)(co*dst_width+wo,no*dst_height+ho),CONVERT_OUTPUT_I4(out));\n"
" }\n"
"}\n"
"#endif\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source (compiled at runtime) for ArgMax / ArgMin over the middle
// axis of a buffer indexed as [outside, dim, inside]; the int result is stored
// at output[z * inside + y]. ARGMAX selects the maximum, otherwise the minimum.
//  - argmax_buf   : one inside position per work-item.
//  - argmax_v4_buf: four consecutive inside positions per work-item
//                   (vload4 / vstore4), compared per component.
// With ARGMAX_LOCAL_SIZE >= 4 the work-items of a group stride over `dim` by
// ARGMAX_LOCAL_SIZE (indexed by get_local_id(0)) and tree-reduce in local
// memory; the reduction halves its stride, so ARGMAX_LOCAL_SIZE is presumably a
// power of two — confirm on the host.
// Ties resolve to the smallest index (first occurrence) on both paths: each lane
// keeps its own first occurrence via a strict comparison, and the tree
// reduction prefers the smaller index when two lanes hold equal values.
// NOTE(review): FLOAT / FLOAT4 come from the build options, not from this string.
const char* argmax_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define ARGMAX_SELECT(A, B, C, D) "" if(A.x < B.x){ A.x = B.x; C.x = D; } "" if(A.y < B.y){ A.y = B.y; C.y = D; } "" if(A.z < B.z){ A.z = B.z; C.z = D; } "" if(A.w<B.w){ A.w=B.w; C.w=D; } \n"
"#define ARGMIN_SELECT(A, B, C, D) "" if(A.x > B.x){ A.x = B.x; C.x = D; } "" if(A.y > B.y){ A.y = B.y; C.y = D; } "" if(A.z > B.z){ A.z = B.z; C.z = D; } "" if(A.w>B.w){ A.w=B.w; C.w=D; } \n"
"__kernel void argmax_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT* input,\n"
" __global int* output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim){\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside\n"
" const int z=get_global_id(2); // outside\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" int index=0;\n"
"#ifdef ARGMAX\n"
" FLOAT maxValue=(FLOAT)-FLT_MAX;\n"
"#else\n"
" FLOAT maxValue=(FLOAT)FLT_MAX;\n"
"#endif\n"
" const int offset=z*dim*inside+y;\n"
"#if ARGMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT local reduce[ARGMAX_LOCAL_SIZE];\n"
" int local index_reduce[ARGMAX_LOCAL_SIZE];\n"
" \n"
" for (int i=lid; i<dim; i+=ARGMAX_LOCAL_SIZE) {\n"
" FLOAT value=input[offset+i*inside];\n"
"#ifdef ARGMAX\n"
" if(maxValue<value){ maxValue=value; index=i; }\n"
"#else\n"
" if(maxValue>value){ maxValue=value; index=i; }\n"
"#endif\n"
" }\n"
" reduce[lid]=maxValue;\n"
" index_reduce[lid]=index;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=ARGMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i){\n"
" // on equal values keep the smaller index (first occurrence)\n"
"#ifdef ARGMAX\n"
" if(reduce[lid]<reduce[lid+i] || (reduce[lid] == reduce[lid+i] && index_reduce[lid+i]<index_reduce[lid])){reduce[lid]=reduce[lid+i]; index_reduce[lid]=index_reduce[lid+i];}\n"
"#else\n"
" if(reduce[lid]>reduce[lid+i] || (reduce[lid] == reduce[lid+i] && index_reduce[lid+i]<index_reduce[lid])){reduce[lid]=reduce[lid+i]; index_reduce[lid]=index_reduce[lid+i];}\n"
"#endif\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" if(lid == 0){\n"
" output[z*inside+y]=index_reduce[0];\n"
" }\n"
"#else\n"
" for(int i=0; i<dim; ++i){\n"
" FLOAT value=input[offset+i*inside];\n"
"#ifdef ARGMAX\n"
" if(maxValue<value){ maxValue=value; index=i; }\n"
"#else\n"
" if(maxValue>value){ maxValue=value; index=i; }\n"
"#endif\n"
" }\n"
" output[z*inside+y]=index;\n"
"#endif\n"
"}\n"
"__kernel void argmax_v4_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT* input,\n"
" __global int* output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim){\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1) << 2; // inside\n"
" const int z=get_global_id(2); // outside\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" int4 index=0;\n"
"#ifdef ARGMAX\n"
" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n"
"#else\n"
" FLOAT4 maxValue=(FLOAT4)FLT_MAX;\n"
"#endif\n"
" const int offset=z*dim*inside+y;\n"
"#if ARGMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local reduce[ARGMAX_LOCAL_SIZE];\n"
" int4 local index_reduce[ARGMAX_LOCAL_SIZE];\n"
" \n"
" for (int i=lid; i<dim; i+=ARGMAX_LOCAL_SIZE) {\n"
" FLOAT4 value=vload4(0,input+offset+i*inside);\n"
"#ifdef ARGMAX\n"
" ARGMAX_SELECT(maxValue,value,index,i);\n"
"#else\n"
" ARGMIN_SELECT(maxValue,value,index,i);\n"
"#endif\n"
" }\n"
" reduce[lid]=maxValue;\n"
" index_reduce[lid]=index;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=ARGMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i){\n"
" // on equal values keep the smaller index (first occurrence), per component\n"
"#ifdef ARGMAX\n"
" if(reduce[lid].x<reduce[lid+i].x || (reduce[lid].x == reduce[lid+i].x && index_reduce[lid+i].x<index_reduce[lid].x)){reduce[lid].x=reduce[lid+i].x; index_reduce[lid].x=index_reduce[lid+i].x;}\n"
" if(reduce[lid].y<reduce[lid+i].y || (reduce[lid].y == reduce[lid+i].y && index_reduce[lid+i].y<index_reduce[lid].y)){reduce[lid].y=reduce[lid+i].y; index_reduce[lid].y=index_reduce[lid+i].y;}\n"
" if(reduce[lid].z<reduce[lid+i].z || (reduce[lid].z == reduce[lid+i].z && index_reduce[lid+i].z<index_reduce[lid].z)){reduce[lid].z=reduce[lid+i].z; index_reduce[lid].z=index_reduce[lid+i].z;}\n"
" if(reduce[lid].w<reduce[lid+i].w || (reduce[lid].w == reduce[lid+i].w && index_reduce[lid+i].w<index_reduce[lid].w)){reduce[lid].w=reduce[lid+i].w; index_reduce[lid].w=index_reduce[lid+i].w;}\n"
"#else\n"
" if(reduce[lid].x>reduce[lid+i].x || (reduce[lid].x == reduce[lid+i].x && index_reduce[lid+i].x<index_reduce[lid].x)){reduce[lid].x=reduce[lid+i].x; index_reduce[lid].x=index_reduce[lid+i].x;}\n"
" if(reduce[lid].y>reduce[lid+i].y || (reduce[lid].y == reduce[lid+i].y && index_reduce[lid+i].y<index_reduce[lid].y)){reduce[lid].y=reduce[lid+i].y; index_reduce[lid].y=index_reduce[lid+i].y;}\n"
" if(reduce[lid].z>reduce[lid+i].z || (reduce[lid].z == reduce[lid+i].z && index_reduce[lid+i].z<index_reduce[lid].z)){reduce[lid].z=reduce[lid+i].z; index_reduce[lid].z=index_reduce[lid+i].z;}\n"
" if(reduce[lid].w>reduce[lid+i].w || (reduce[lid].w == reduce[lid+i].w && index_reduce[lid+i].w<index_reduce[lid].w)){reduce[lid].w=reduce[lid+i].w; index_reduce[lid].w=index_reduce[lid+i].w;}\n"
"#endif\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" if(lid == 0){\n"
" vstore4(index_reduce[0],0,output+z*inside+y);\n"
" }\n"
"#else\n"
" for(int i=0; i<dim; ++i){\n"
" FLOAT4 value=vload4(0,input+offset+i*inside);\n"
"#ifdef ARGMAX\n"
" ARGMAX_SELECT(maxValue,value,index,i);\n"
"#else\n"
" ARGMIN_SELECT(maxValue,value,index,i);\n"
"#endif\n"
" }\n"
" vstore4(index,0,output+z*inside+y);\n"
"#endif\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
// OpenCL C source (compiled at runtime) for the layout conversions used by the
// Intel-subgroup buffer path. NC16HW16 buffers here are indexed as
// [N, ceil(C/16), H, pad_left + W + pad_right, 16]; on the output side, the
// work-item with width_idx == 0 zero-fills the pad columns of its row:
// left pad = columns [0, pad_left), right pad = columns
// [pad_left + W, pad_left + W + pad_right).
//  - nhwc_buffer_to_nc16hw16_buffer / nchw_buffer_to_nc16hw16_buffer
//  - nc16hw16_buffer_to_nhwc_buffer / nc16hw16_buffer_to_nchw_buffer
//  - nc4hw4_buffer_to_nc16hw16_buffer / nc16hw16_buffer_to_nc4hw4_buffer
// Dim 0 of the global range enumerates (16-channel block, width); dim 1
// enumerates (batch, height).
// Invariants kept by the kernels below:
//  - the right pad always starts at column pad_left + width (all producers);
//  - a padded NCHW source (input_pad_left/right) has row stride src_width and
//    therefore channel stride height * src_width;
//  - NC4HW4 -> NC16HW16 zero-fills 4-channel blocks >= channelc4 instead of
//    reading past the real channel blocks.
// NOTE(review): INPUT_TYPE* / OUTPUT_TYPE* / CONVERT_OUTPUT* come from the build
// options; src_stride / dst_stride are presumably (channel blocks, 1) — confirm.
const char* buffer_convert_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"// convert data from buffer(nhwc) to buffer(nc16hw16) float input\n"
"__kernel void nhwc_buffer_to_nc16hw16_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int height,\n"
" __private const int width,__private const int channels,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" const int channel_16_idx=(image_width_idx/width) << 4;\n"
" const int buffer_offset=((batch_idx*height+height_idx)*width+width_idx)*channels+channel_16_idx;\n"
" const int remain_channel=min(channels-channel_16_idx,16);\n"
" INPUT_TYPE16 values=0;\n"
" INPUT_TYPE* values_ptr=(INPUT_TYPE*)(&values);\n"
" __global const INPUT_TYPE *input_current_ptr=input_ptr+buffer_offset;\n"
" for(int i=0; i<remain_channel; ++i){\n"
" values_ptr[i]=*(input_current_ptr+i);\n"
" }\n"
" const int out_offset=(((batch_idx*((channels+15)/16)+channel_16_idx/16)*height+height_idx)*(output_pad_left+width+output_pad_right)+width_idx+output_pad_left)*16;\n"
" vstore16(CONVERT_OUTPUT16(values),0,output+out_offset);\n"
" if(width_idx == 0){\n"
" int pad_offset=(((batch_idx*((channels+15)/16)+channel_16_idx/16)*height+height_idx)*(output_pad_left+width+output_pad_right))*16;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore16((OUTPUT_TYPE16)0,0,output+pad_offset+i*16);\n"
" }\n"
" // right pad starts after the left pad and the data columns\n"
" pad_offset += (output_pad_left+width)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore16((OUTPUT_TYPE16)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
"// convert data from buffer(nchw) to buffer(nc16hw16)\n"
"__kernel void nchw_buffer_to_nc16hw16_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int height,__private const int width,__private const int channels,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int src_width=width+input_pad_left+input_pad_right;\n"
" const int dst_width=width+output_pad_left+output_pad_right;\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" const int channel_16_idx=image_width_idx/width << 4;\n"
" const int buffer_offset=((batch_idx*channels+channel_16_idx)*height+height_idx)*src_width+width_idx+input_pad_left;\n"
" const int remain_channel=min(channels-channel_16_idx,16);\n"
" // rows are src_width wide (input pads included), so one channel spans height*src_width\n"
" const int height_width_size=height*src_width;\n"
" INPUT_TYPE16 output_values=0;\n"
" INPUT_TYPE *output_values_ptr=(INPUT_TYPE*)(&output_values);\n"
" for(int i=0; i<remain_channel; ++i){\n"
" output_values_ptr[i]=*(input_ptr+buffer_offset+height_width_size*i);\n"
" }\n"
" if(width_idx == 0){\n"
" int pad_offset=(((batch_idx*((channels+15)/16)+channel_16_idx/16)*height+height_idx)*dst_width+0)*16;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore16((OUTPUT_TYPE16)0,0,output+pad_offset+16*i);\n"
" }\n"
" pad_offset += 16*(width+output_pad_left);\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore16((OUTPUT_TYPE16)0,0,output+pad_offset+16*i);\n"
" }\n"
" }\n"
" const int out_offset=(((batch_idx*((channels+15)/16)+channel_16_idx/16)*height+height_idx)*dst_width+width_idx+output_pad_left)*16;\n"
" vstore16(CONVERT_OUTPUT16(output_values),0,output+out_offset);\n"
"}\n"
"// convert data from image(b h,ic/16 w ic16) to buffer(nhwc)\n"
"__kernel void nc16hw16_buffer_to_nhwc_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int height,__private const int width,\n"
" __private const int channels,\n"
" __global INPUT_TYPE *input_ptr,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" const int channel_16_idx=(image_width_idx/width) << 4;\n"
" const int buffer_offset=((batch_idx*height+height_idx)*width+width_idx)*channels+channel_16_idx;\n"
" const int in_offset=(((batch_idx*((channels+15)/16)+channel_16_idx/16)*height+height_idx)*(input_pad_left+width+input_pad_right)+width_idx+input_pad_left)*16;\n"
" INPUT_TYPE16 values=vload16(0,input_ptr+in_offset);\n"
" INPUT_TYPE* values_ptr=(INPUT_TYPE*)(&values);\n"
" const int remain_channel=min(channels-channel_16_idx,16);\n"
" for(int i=0; i<remain_channel; ++i){\n"
" output[buffer_offset+i]=(OUTPUT_TYPE)values_ptr[i];\n"
" }\n"
"}\n"
"// convert data from buffer(nc16hw16) to buffer(nchw)\n"
"__kernel void nc16hw16_buffer_to_nchw_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int height,__private const int width,\n"
" __private const int channels,\n"
" __global INPUT_TYPE *input_ptr,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" \n"
" const int src_width=width+input_pad_left+input_pad_right;\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" int channel_16_idx=(image_width_idx/width) << 4;\n"
" int buffer_offset=((batch_idx*channels+channel_16_idx)*height+height_idx)*width+width_idx;\n"
" \n"
" const int in_offset=(((batch_idx*((channels+15)/16)+channel_16_idx/16)*height+height_idx)*src_width+width_idx+input_pad_left)*16;\n"
" INPUT_TYPE16 values=vload16(0,input_ptr+in_offset);\n"
" INPUT_TYPE *values_ptr=(INPUT_TYPE*)(&values);\n"
" const int height_width_size=height*width;\n"
" const int remain_channel=min(channels-channel_16_idx,16);\n"
" for(int i=0; i<remain_channel; ++i){\n"
" output[buffer_offset+i*height_width_size]=(OUTPUT_TYPE)values_ptr[i];\n"
" }\n"
"}\n"
"__kernel void nc4hw4_buffer_to_nc16hw16_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int2 output_shape,\n"
" __private const int2 src_stride,\n"
" __private const int2 dst_stride,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right,\n"
" __private const int channelc4\n"
") {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/output_shape.x;\n"
" const int height_idx=image_height_idx % output_shape.x;\n"
" const int width_idx=image_width_idx % output_shape.y;\n"
" const int channel_block_idx=image_width_idx/output_shape.y;\n"
" const int in_channel_block_idx=channel_block_idx << 2;\n"
" const int dst_width=output_pad_left+output_shape.y+output_pad_right;\n"
" int2 src_bc_offset=src_stride*(int2)(batch_idx,in_channel_block_idx);\n"
" int2 dst_bc_offset=dst_stride*(int2)(batch_idx,channel_block_idx);\n"
" int src_buffer_offset =\n"
" (((src_bc_offset.x+src_bc_offset.y)*output_shape.x+height_idx)*output_shape.y+width_idx)*4;\n"
" int dst_buffer_offset =\n"
" (((dst_bc_offset.x+dst_bc_offset.y)*output_shape.x+height_idx)*dst_width+width_idx+output_pad_left)*16;\n"
" int width_height_size4=output_shape.x*output_shape.y*4;\n"
" // 4-channel blocks at or beyond channelc4 do not exist in the source: zero-fill them\n"
" INPUT_TYPE4 values0=vload4(0,input_ptr+src_buffer_offset);\n"
" INPUT_TYPE4 values1=in_channel_block_idx+1 >= channelc4 ? (INPUT_TYPE4)0 : vload4(0,input_ptr+src_buffer_offset+width_height_size4);\n"
" INPUT_TYPE4 values2=in_channel_block_idx+2 >= channelc4 ? (INPUT_TYPE4)0 : vload4(0,input_ptr+src_buffer_offset+width_height_size4*2);\n"
" INPUT_TYPE4 values3=in_channel_block_idx+3 >= channelc4 ? (INPUT_TYPE4)0 : vload4(0,input_ptr+src_buffer_offset+width_height_size4*3);\n"
" \n"
" vstore16(CONVERT_OUTPUT16((INPUT_TYPE16)(values0.s0123,values1.s0123,values2.s0123,values3.s0123)),0,output+dst_buffer_offset);\n"
" if(width_idx == 0){\n"
" int pad_offset=(((dst_bc_offset.x+dst_bc_offset.y)*output_shape.x+height_idx)*dst_width)*16;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore16((OUTPUT_TYPE16)0,0,output+pad_offset+16*i);\n"
" }\n"
" pad_offset += 16*(output_shape.y+output_pad_left);\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore16((OUTPUT_TYPE16)0,0,output+pad_offset+16*i);\n"
" }\n"
" }\n"
"}\n"
"__kernel void nc16hw16_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int2 output_shape,\n"
" __private const int2 src_stride,\n"
" __private const int2 dst_stride,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right,\n"
" __private const int channelc4\n"
") {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/output_shape.x;\n"
" const int height_idx=image_height_idx % output_shape.x;\n"
" const int width_idx=image_width_idx % output_shape.y;\n"
" const int channel_block_idx=image_width_idx/output_shape.y;\n"
" const int out_channel_block_idx=channel_block_idx << 2;\n"
" int2 src_bc_offset=src_stride*(int2)(batch_idx,channel_block_idx);\n"
" int2 dst_bc_offset=dst_stride*(int2)(batch_idx,out_channel_block_idx);\n"
" int width_height_size4=output_shape.x*output_shape.y*4;\n"
" int src_buffer_offset =\n"
" (((src_bc_offset.x+src_bc_offset.y)*output_shape.x+height_idx)*(input_pad_left+output_shape.y+input_pad_right)+width_idx+input_pad_left)*16;\n"
" int dst_buffer_offset =\n"
" (((dst_bc_offset.x+dst_bc_offset.y)*output_shape.x+height_idx)*output_shape.y+width_idx)*4;\n"
" INPUT_TYPE16 values=vload16(0,input_ptr+src_buffer_offset);\n"
" \n"
" vstore4(CONVERT_OUTPUT4(values.s0123),0,output+dst_buffer_offset);\n"
" if(out_channel_block_idx+1 >= channelc4) return;\n"
" vstore4(CONVERT_OUTPUT4(values.s4567),0,output+dst_buffer_offset+width_height_size4);\n"
" if(out_channel_block_idx+2 >= channelc4) return;\n"
" vstore4(CONVERT_OUTPUT4(values.s89ab),0,output+dst_buffer_offset+2*width_height_size4);\n"
" if(out_channel_block_idx+3 >= channelc4) return;\n"
" vstore4(CONVERT_OUTPUT4(values.scdef),0,output+dst_buffer_offset+3*width_height_size4);\n"
"}\n"
;
#endif
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// attention_buf: OpenCL C source, held as one concatenated string literal, for
// the buffer-based attention kernels rearrange_qkv, rearrange_mask,
// qkv_transpose_output, matmul_qk_div_mask and matmul_qkv. Omitted from the
// build when MNN_OPENCL_BUFFER_CLOSED is defined.
// NOTE(review): FLOAT/FLOAT4, COMPUTE_FLOAT*, CONVERT_* and similar type macros
// are not defined inside this string; presumably the host passes them as
// program build options -- confirm against the OpenCL runtime code.
const char* attention_buf = 
// Turn on the cl_khr_fp16 extension when the kernel build defines
// MNN_SUPPORT_FP16.
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_3_DIMS expands to three leading kernel parameters carrying the
// logical global sizes. DEAL_NON_UNIFORM_DIM3 makes a work-item return early
// when any of its ids lies outside those sizes (the enqueued range may be
// larger than the logical one).
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// DEAL_OUTER_SEQLEN_NOT_ALIGN(length): zero temp_1..temp_3 when their sequence
// row 4*sl+k (k = 1..3) is >= length. Expects locals sl and temp_1..temp_3 at
// the expansion site.
"#define DEAL_OUTER_SEQLEN_NOT_ALIGN(length) "" if(4 * sl + 3 >= length) {"" temp_3 = (FLOAT4)0;"" }"" if(4 * sl + 2 >= length) {"" temp_2 = (FLOAT4)0;"" }"" if(4 * sl + 1 >= length) {"" temp_1 = (FLOAT4)0;"" }\n"
// DEAL_INNER_HEADDIM_NOT_ALIGN(length): zero the y/z/w lanes of temp_0..temp_3
// when head-dim index hd*4+k (k = 1..3) is >= length. Expects locals hd and
// temp_0..temp_3 at the expansion site.
"#define DEAL_INNER_HEADDIM_NOT_ALIGN(length) "" if(hd * 4 + 3 >= length) {"" temp_0.w = (FLOAT)0;"" temp_1.w = (FLOAT)0;"" temp_2.w = (FLOAT)0;"" temp_3.w = (FLOAT)0;"" }"" if(hd * 4 + 2 >= length) {"" temp_0.z = (FLOAT)0;"" temp_1.z = (FLOAT)0;"" temp_2.z = (FLOAT)0;"" temp_3.z = (FLOAT)0;"" }"" if(hd * 4 + 1 >= length) {"" temp_0.y = (FLOAT)0;"" temp_1.y = (FLOAT)0;"" temp_2.y = (FLOAT)0;"" temp_3.y = (FLOAT)0;"" }\n"
// rearrange_qkv: one work-item per (sl, hd, z). sl selects four sequence rows,
// hd selects four head-dim columns, and z encodes the batch index
// b = z % batch and the head hn = z / batch. It copies 4x4 tiles of Q, K and V
// into zero-padded buffers whose dimensions are rounded up to the tile sizes in
// `tile`, and mirrors the K/V tiles into past_k/past_v.
// NOTE(review): the per-argument layout comments describe a seq-packed layout,
// but the live offsets (in_offset_q / in_offset_kv) index the inputs as
// [batch, seqLen, heads, headDim]; the commented-out in_offset_q line is the
// packed variant. param.x is group and param.y is batch.
"__kernel void rearrange_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input_q,//[batch,seqLenQ/4,headNum,headDim,seqLenQ_4]\n"
" __global const FLOAT *input_k,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n"
" __global const FLOAT *input_v,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n"
" __global FLOAT *output_q,// [batch*headNum,ROUND_UP(headDim,mTileHDK),ROUND_UP(seqLenQ,mTileQ)]\n"
" __global FLOAT *output_k,// [batch*headNum/group,ROUND_UP(headDim,mTileHDK),ROUND_UP(seqLenKV,mTileKV)]\n"
" __global FLOAT *output_v,// [batch*headNum/group,ROUND_UP(seqLenKV,mTileKV),ROUND_UP(headDim,mTileHDN)]\n"
" __global FLOAT *past_k,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n"
" __global FLOAT *past_v,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n"
" __private const int4 tile,// [mTileQ,mTileKV,mTileHDK,mTileHDN]\n"
" __private const int4 shape,// [seqLenQ,seqLenKV,headNum,headDim]\n"
" __private const int4 param // [group,batch]\n"
") {\n"
" const int sl=get_global_id(0); // seqLen/4 : max(seqLenPackQ/4,seqLenPackKV/4)\n"
" const int hd=get_global_id(1); // headDim/4 : max(headDimPackQK/4,headDimPackV/4)\n"
" const int z=get_global_id(2); // batch*headNum\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int seqLenQ=shape.x;\n"
" const int seqLenKV=shape.y;\n"
" const int headNum=shape.z;\n"
" const int headDim=shape.w;\n"
" const int group=param.x;\n"
" const int batch=param.y;\n"
" const int b=z % batch;\n"
" const int hn=z/batch;\n"
" \n"
// Q: output_q is indexed as [batch*headNum, headDimPackQK, seqLenPackQ]. Each
// 4x4 tile is transposed so four consecutive sequence positions of one
// head-dim index become contiguous. seqLenQ_4 is only used by the
// commented-out formula.
" const int seqLenQ_4=(seqLenQ+3)/4;\n"
" //const int in_offset_q=(((b*seqLenQ_4+sl)*headNum+hn)*headDim+4*hd)*4;\n"
" const int in_offset_q=(((b*seqLenQ+sl*4)*headNum+hn)*headDim+4*hd);\n"
" const int seqLenPackQ=((seqLenQ+tile.x-1)/tile.x)*tile.x;\n"
" const int headDimPackQK=((headDim+tile.z-1)/tile.z)*tile.z;\n"
" const int out_offset_q=(((b*headNum+hn)*headDimPackQK+hd*4)*seqLenPackQ+sl*4);\n"
" \n"
// Tiles inside the padded extent but outside the real seqLenQ/headDim are
// written as zeros; rows past seqLenQ inside a real tile are zeroed as well.
" if(sl*4<seqLenPackQ && hd*4<headDimPackQK) {\n"
" if(sl*4 >= seqLenQ || hd*4 >= headDim) {\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q);\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q+seqLenPackQ);\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q+2*seqLenPackQ);\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q+3*seqLenPackQ);\n"
" } else {\n"
" FLOAT4 temp_0=vload4(0,input_q+in_offset_q);\n"
" FLOAT4 temp_1=(sl*4+1 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+headNum*headDim);\n"
" FLOAT4 temp_2=(sl*4+2 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+2*headNum*headDim);\n"
" FLOAT4 temp_3=(sl*4+3 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+3*headNum*headDim);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenQ)\n"
" #endif\n"
" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output_q+out_offset_q);\n"
" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output_q+out_offset_q+seqLenPackQ);\n"
" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output_q+out_offset_q+2*seqLenPackQ);\n"
" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output_q+out_offset_q+3*seqLenPackQ);\n"
" }\n"
" }\n"
" \n"
// K and V exist only for headNum/group heads; work-items for the remaining
// heads finish after handling Q.
" if(hn >= headNum/group) {\n"
" return;\n"
" }\n"
" \n"
// seqLenKV_4 is not referenced below.
" const int seqLenPackKV=((seqLenKV+tile.y-1)/tile.y)*tile.y;\n"
" const int headDimPackV=((headDim+tile.w-1)/tile.w)*tile.w;\n"
" const int seqLenKV_4=(seqLenKV+3)/4;\n"
" const int in_offset_kv=(((b*seqLenKV+sl*4)*headNum/group+hn)*headDim+4*hd);\n"
" \n"
// K: same transposed layout as Q in output_k ([batch*headNum/group,
// headDimPackQK, seqLenPackKV]). The untransposed tile is also written to
// past_k at the input offset, skipping rows at or beyond seqLenKV.
" if(sl*4<seqLenPackKV && hd*4<headDimPackQK) {\n"
" const int out_offset_k=(((b*headNum/group+hn)*headDimPackQK+hd*4)*seqLenPackKV+sl*4);\n"
" if(sl*4 >= seqLenKV || hd*4 >= headDim) {\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k);\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k+seqLenPackKV);\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k+2*seqLenPackKV);\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k+3*seqLenPackKV);\n"
" } else {\n"
" FLOAT4 temp_0=vload4(0,input_k+in_offset_kv);\n"
" FLOAT4 temp_1=(sl*4+1 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+headNum*headDim/group);\n"
" FLOAT4 temp_2=(sl*4+2 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+2*headNum*headDim/group);\n"
" FLOAT4 temp_3=(sl*4+3 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+3*headNum*headDim/group);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV)\n"
" #endif\n"
" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output_k+out_offset_k);\n"
" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output_k+out_offset_k+seqLenPackKV);\n"
" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output_k+out_offset_k+2*seqLenPackKV);\n"
" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output_k+out_offset_k+3*seqLenPackKV);\n"
" \n"
// NOTE(review): with HEADDIM_LEAVE, the last hd tile has its out-of-range lanes
// zeroed, yet vstore4 below still writes all four lanes to past_k. When
// headDim % 4 != 0, those lanes alias the next head's first elements and may
// race with that head's work-item. The same applies to past_v further down.
// Verify that headDim is a multiple of 4 here, or switch to partial stores.
" // pastK\n"
" vstore4(temp_0,0,past_k+in_offset_kv);\n"
" if(sl*4+1<seqLenKV) {\n"
" vstore4(temp_1,0,past_k+in_offset_kv+headNum*headDim/group);\n"
" }\n"
" if(sl*4+2<seqLenKV) {\n"
" vstore4(temp_2,0,past_k+in_offset_kv+2*headNum*headDim/group);\n"
" }\n"
" if(sl*4+3<seqLenKV) {\n"
" vstore4(temp_3,0,past_k+in_offset_kv+3*headNum*headDim/group);\n"
" }\n"
" }\n"
" \n"
" }\n"
" \n"
// V: output_v is indexed as [batch*headNum/group, seqLenPackKV, headDimPackV]
// and is NOT transposed. past_v is handled exactly like past_k.
" if(sl*4<seqLenPackKV && hd*4<headDimPackV) {\n"
" const int out_offset_v=(((b*headNum/group+hn)*seqLenPackKV+sl*4)*headDimPackV+hd*4);\n"
" if(sl*4 >= seqLenKV || hd*4 >= headDim) {\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v);\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v+headDimPackV);\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v+2*headDimPackV);\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v+3*headDimPackV);\n"
" } else {\n"
" FLOAT4 temp_0=vload4(0,input_v+in_offset_kv);\n"
" FLOAT4 temp_1=(sl*4+1 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+headNum*headDim/group);\n"
" FLOAT4 temp_2=(sl*4+2 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+2*headNum*headDim/group);\n"
" FLOAT4 temp_3=(sl*4+3 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+3*headNum*headDim/group);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV)\n"
" #endif\n"
" vstore4(temp_0,0,output_v+out_offset_v);\n"
" vstore4(temp_1,0,output_v+out_offset_v+headDimPackV);\n"
" vstore4(temp_2,0,output_v+out_offset_v+2*headDimPackV);\n"
" vstore4(temp_3,0,output_v+out_offset_v+3*headDimPackV);\n"
" \n"
" // pastV\n"
" vstore4(temp_0,0,past_v+in_offset_kv);\n"
" if(sl*4+1<seqLenKV) {\n"
" vstore4(temp_1,0,past_v+in_offset_kv+headNum*headDim/group);\n"
" }\n"
" if(sl*4+2<seqLenKV) {\n"
" vstore4(temp_2,0,past_v+in_offset_kv+2*headNum*headDim/group);\n"
" }\n"
" if(sl*4+3<seqLenKV) {\n"
" vstore4(temp_3,0,past_v+in_offset_kv+3*headNum*headDim/group);\n"
" }\n"
" }\n"
" \n"
" }\n"
"}\n"
// Mask element type defaults to FLOAT unless the build defines MASK_DTYPE (in
// which case MASK_DTYPE4 has to be defined alongside it).
"#ifndef MASK_DTYPE\n"
"#define MASK_DTYPE FLOAT\n"
"#define MASK_DTYPE4 FLOAT4\n"
"#endif\n"
// rearrange_mask: copies the 4x4 tile at rows sl*4.. and columns sl_kv*4.. of
// batch b's mask into output_mask. The output's row and column counts are
// rounded up to shape.z (mTileQ) and shape.w (mTileKV). Tiles and elements
// outside the real shape.x x shape.y extent are written as zeros. Despite the
// trailing 4 in the argument comment, in_offset indexes input_mask as
// [batch, seqLenQ, seqLenKV].
"__kernel void rearrange_mask(GLOBAL_SIZE_3_DIMS\n"
" __global const MASK_DTYPE *input_mask,// [batch,1,seqLenQ,seqLenKV,4]\n"
" __global MASK_DTYPE *output_mask,// [batch,ROUND_UP(seqLenQ,mTileQ),ROUND_UP(seqLenKV,mTileKV)]\n"
" const int4 shape // [seqLenQ,seqLenKV,mTileQ,mTileKV]\n"
") {\n"
" const int sl=get_global_id(0); // seqLen_4\n"
" const int sl_kv=get_global_id(1); // seqLenKV_4\n"
" const int b=get_global_id(2); // Batch\n"
" DEAL_NON_UNIFORM_DIM3(sl,sl_kv,b);\n"
" \n"
" const int seq_len_pack=((shape.x+shape.z-1)/shape.z)*shape.z;\n"
" const int seq_len_kv_pack=((shape.y+shape.w-1)/shape.w)*shape.w;\n"
" int in_offset=((b*shape.x+sl*4)*shape.y+sl_kv*4);\n"
" int out_offset=(b*seq_len_pack+sl*4)*seq_len_kv_pack+sl_kv*4;\n"
" if(sl*4 >= shape.x || sl_kv*4 >= shape.y) {\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset);\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack);\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack*2);\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack*3);\n"
" } else {\n"
" int y_down_align4=(shape.y/4*4);\n"
" MASK_DTYPE4 temp_0,temp_1,temp_2,temp_3;\n"
" \n"
// Full 4-column tiles use vector loads. The last partial tile (1-3 valid
// columns) is gathered element-wise so reads stay within the row.
" if(sl_kv*4<y_down_align4) {\n"
" temp_0=vload4(0,input_mask+in_offset);\n"
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y*3);\n"
" } else if(sl_kv*4+1 == shape.y){\n"
" temp_0=(MASK_DTYPE4)(input_mask[in_offset],0,0,0);\n"
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],0,0,0);//vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],0,0,0);//vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],0,0,0);//vload4(0,input_mask+in_offset+shape.y*3);\n"
" } else if(sl_kv*4+2 == shape.y){\n"
" temp_0=(MASK_DTYPE4)(input_mask[in_offset],input_mask[in_offset+1],0,0);\n"
// Every vector literal here must be MASK_DTYPE4. OpenCL C has no implicit
// conversion between vector types, so building this row as FLOAT4 does not
// compile whenever MASK_DTYPE is not FLOAT.
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],input_mask[in_offset+shape.y+1],0,0);//vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],input_mask[in_offset+shape.y*2+1],0,0);//vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],input_mask[in_offset+shape.y*3+1],0,0);//vload4(0,input_mask+in_offset+shape.y*3);\n"
" } else if(sl_kv*4+3 == shape.y){\n"
" temp_0=(MASK_DTYPE4)(input_mask[in_offset],input_mask[in_offset+1],input_mask[in_offset+2],0);\n"
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],input_mask[in_offset+shape.y+1],input_mask[in_offset+shape.y+2],0);//vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],input_mask[in_offset+shape.y*2+1],input_mask[in_offset+shape.y*2+2],0);//vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],input_mask[in_offset+shape.y*3+1],input_mask[in_offset+shape.y*3+2],0);//vload4(0,input_mask+in_offset+shape.y*3);\n"
" }\n"
" vstore4(temp_0,0,output_mask+out_offset);\n"
" vstore4(temp_1,0,output_mask+out_offset+seq_len_kv_pack);\n"
" vstore4(temp_2,0,output_mask+out_offset+2*seq_len_kv_pack);\n"
" vstore4(temp_3,0,output_mask+out_offset+3*seq_len_kv_pack);\n"
" }\n"
"}\n"
// qkv_transpose_output: reads a 4x4 tile from input, which is laid out as
// [batch*head_num, head_dim_pack, seq_len_pack] (head-dim rows, padded sequence
// columns). It writes the tile transposed to output at
// ((b*seq_len + s)*head_num + hn)*head_dim + d. That is an unpadded
// [batch, seq_len, head_num, head_dim] layout, despite the seq-packed shape in
// the argument comment. Here z = b*head_num + hn. Sequence rows at or beyond
// seq_len are skipped by the early returns.
// NOTE(review): every store writes four head-dim values. If head_dim % 4 != 0,
// the last hd tile spills into the next head's first elements -- confirm
// head_dim alignment or switch to partial stores.
"__kernel void qkv_transpose_output(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [Batch*mNumHead,ROUND_UP(mHeadDim,mTileHDN),ROUND_UP(seqLen,mTileQ)]\n"
" __global FLOAT *output,// [Batch,seqLen/4,mNumHead， mHeadDim,4]\n"
" __private const int tile_q,\n"
" __private const int tile_hdn,\n"
" __private const int seq_len,\n"
" __private const int head_num,\n"
" __private const int head_dim\n"
") {\n"
" \n"
" const int sl=get_global_id(0); // seqLen_4\n"
" const int hd=get_global_id(1); // mHeadDim_4\n"
" const int z=get_global_id(2); // Batch*mNumHead\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int b=z/head_num;\n"
" const int hn=z % head_num;\n"
" \n"
" const int seq_len_pack=((seq_len+tile_q-1)/tile_q)*tile_q;\n"
" const int head_dim_pack=((head_dim+tile_hdn-1)/tile_hdn)*tile_hdn;\n"
" \n"
" const int offset_inp=((b*head_num+hn)*head_dim_pack+4*hd)*seq_len_pack+4*sl;\n"
" \n"
" const int offset_out=(((b*seq_len+sl*4)*head_num+hn)*head_dim+4*hd);\n"
" \n"
// temp_k holds head-dim row 4*hd+k for sequence positions 4*sl..4*sl+3; the
// stores below gather lane j of each temp_k to form sequence row 4*sl+j.
" // Q\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+seq_len_pack);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+2*seq_len_pack);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+3*seq_len_pack);\n"
" \n"
" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output+offset_out);\n"
" if(4*sl+1 >= seq_len) return;\n"
" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output+offset_out+head_num*head_dim);\n"
" if(4*sl+2 >= seq_len) return;\n"
" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output+offset_out+2*head_num*head_dim);\n"
" if(4*sl+3 >= seq_len) return;\n"
" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output+offset_out+3*head_num*head_dim);\n"
"}\n"
// NUMHEAD_GROUP_SIZE: number of query heads sharing one KV head (query head z
// reads KV head z / NUMHEAD_GROUP_SIZE). Defaults to 1 unless the build
// defines it.
"#ifndef NUMHEAD_GROUP_SIZE\n"
"#define NUMHEAD_GROUP_SIZE 1\n"
"#endif\n"
// matmul_qk_div_mask: attention scores scale * (Q . K) for query head z.
// Dot products accumulate in float.
//  Prefill (OPENCL_PREFILL_ATTENTION defined): each work-item produces a 4x4
//  block, pairing query rows y*4..y*4+3 with key rows x*4..x*4+3. It stores the
//  block at output[(z*query_seq_len + q)*key_seq_len + k] and copies the key
//  rows it read from input1 into past_key. With ADD_MASK the mask is added to
//  the scores; otherwise scores whose int mask entry is 0 become -FLT_MAX.
//  Decode: a single query row (y4 presumably 0). Each work-item scores key rows
//  x*4..x*4+3 read from past_key. The work-item owning the last key block also
//  scores the newest key straight from input1, overwriting that lane, and
//  writes the key into past_key row key_seq_len-1. No mask is applied here.
// NOTE(review): decode reads past_key rows up to x4+3 even past key_seq_len;
// this assumes past_key is allocated with rows rounded up to a multiple of 4.
"__kernel void matmul_qk_div_mask(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input0,// query [1 query_seq_len head_num head_dim]\n"
" __global const FLOAT *input1,// key [1 key_seq_len head_num head_dim]\n"
" __global FLOAT *output,// prefill [1 head_num query_seq_len key_seq_len] decode[1 head_num key_seq_len/4 4]\n"
" __global FLOAT *past_key,// [1 max_length head_num head_dim]\n"
" #ifdef ADD_MASK\n"
" __global const FLOAT* mask,\n"
" #else\n"
" __global const int* mask,// [1 1 query_seq_len key_seq_len]\n"
" #endif\n"
" __private const float scale,\n"
" __private const int query_seq_len,\n"
" __private const int key_seq_len,\n"
" __private const int head_num,\n"
" __private const int kv_head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // key_seq_len\n"
" const int y=get_global_id(1); // query_seq_len for prefill 1 for decode\n"
" const int z=get_global_id(2); // head_num\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" int x4=x << 2;\n"
" int y4=y << 2;\n"
" int zin=z/NUMHEAD_GROUP_SIZE;\n"
// A rows are query positions (stride strideA over all query heads); B and
// past_key rows are key positions (stride strideB over the KV heads).
" __global const FLOAT *A_offset=input0+(y4*head_num+z)*head_dim;\n"
" __global FLOAT *Pastkey_offset=past_key+(x4*kv_head_num+zin)*head_dim;\n"
" int strideA=head_num*head_dim;\n"
" int strideB=kv_head_num*head_dim;\n"
"#ifdef OPENCL_PREFILL_ATTENTION\n"
" __global const FLOAT *B_offset=input1+(x4*kv_head_num+zin)*head_dim;\n"
" int output_offset=(z*query_seq_len+y4)*key_seq_len+x4;\n"
" float4 out0=0;\n"
" float4 out1=0;\n"
" float4 out2=0;\n"
" float4 out3=0;\n"
" \n"
// Rows past the real query/key length are loaded as zero and never stored.
" bool A1_enable=y4+1<query_seq_len;\n"
" bool A2_enable=y4+2<query_seq_len;\n"
" bool A3_enable=y4+3<query_seq_len;\n"
" \n"
" bool B1_enable=x4+1<key_seq_len;\n"
" bool B2_enable=x4+2<key_seq_len;\n"
" bool B3_enable=x4+3<key_seq_len;\n"
" \n"
" const int head_dim4=(head_dim+3)/4;\n"
// With HEADDIM_LEAVE, the final (possibly partial) group of four head-dim
// values is processed element-wise so loads stay within head_dim.
" #ifdef HEADDIM_LEAVE\n"
" for(int i=0; i<head_dim4-1; ++i){\n"
" float4 A0=convert_float4(vload4(i,A_offset));\n"
" float4 A1=A1_enable ? convert_float4(vload4(i,A_offset+strideA)) : (float4)0;\n"
" float4 A2=A2_enable ? convert_float4(vload4(i,A_offset+strideA+strideA)) : (float4)0;\n"
" float4 A3=A3_enable ? convert_float4(vload4(i,A_offset+strideA+strideA+strideA)) : (float4)0;\n"
" float4 B0=convert_float4(vload4(i,B_offset));\n"
" float4 B1=B1_enable ? convert_float4(vload4(i,B_offset+strideB)) : (float4)0;\n"
" float4 B2=B2_enable ? convert_float4(vload4(i,B_offset+strideB+strideB)) : (float4)0;\n"
" float4 B3=B3_enable ? convert_float4(vload4(i,B_offset+strideB+strideB+strideB)) : (float4)0;\n"
" \n"
" out0.x += dot(A0,B0);\n"
" out0.y += dot(A0,B1);\n"
" out0.z += dot(A0,B2);\n"
" out0.w += dot(A0,B3);\n"
" \n"
" out1.x += dot(A1,B0);\n"
" out1.y += dot(A1,B1);\n"
" out1.z += dot(A1,B2);\n"
" out1.w += dot(A1,B3);\n"
" \n"
" out2.x += dot(A2,B0);\n"
" out2.y += dot(A2,B1);\n"
" out2.z += dot(A2,B2);\n"
" out2.w += dot(A2,B3);\n"
" \n"
" out3.x += dot(A3,B0);\n"
" out3.y += dot(A3,B1);\n"
" out3.z += dot(A3,B2);\n"
" out3.w += dot(A3,B3);\n"
" \n"
" vstore4(CONVERT_FLOAT4(B0),i,Pastkey_offset);\n"
" vstore4(CONVERT_FLOAT4(B1),i,Pastkey_offset+strideB);\n"
" vstore4(CONVERT_FLOAT4(B2),i,Pastkey_offset+strideB+strideB);\n"
" vstore4(CONVERT_FLOAT4(B3),i,Pastkey_offset+strideB+strideB+strideB);\n"
" }\n"
" for(int i=(head_dim4-1)*4; i<head_dim; ++i){\n"
" float A0=A_offset[i];\n"
" float A1=A1_enable ? A_offset[i+strideA] : 0;\n"
" float A2=A2_enable ? A_offset[i+strideA+strideA] : 0;\n"
" float A3=A3_enable ? A_offset[i+strideA+strideA+strideA] : 0;\n"
" float B0=B_offset[i];\n"
" float B1=B1_enable ? B_offset[i+strideB] : 0;\n"
" float B2=B2_enable ? B_offset[i+strideB+strideB] : 0;\n"
" float B3=B3_enable ? B_offset[i+strideB+strideB+strideB] : 0;\n"
" \n"
" out0.x += A0*B0;\n"
" out0.y += A0*B1;\n"
" out0.z += A0*B2;\n"
" out0.w += A0*B3;\n"
" \n"
" out1.x += A1*B0;\n"
" out1.y += A1*B1;\n"
" out1.z += A1*B2;\n"
" out1.w += A1*B3;\n"
" \n"
" out2.x += A2*B0;\n"
" out2.y += A2*B1;\n"
" out2.z += A2*B2;\n"
" out2.w += A2*B3;\n"
" \n"
" out3.x += A3*B0;\n"
" out3.y += A3*B1;\n"
" out3.z += A3*B2;\n"
" out3.w += A3*B3;\n"
" \n"
" Pastkey_offset[i]=(FLOAT)B0;\n"
" Pastkey_offset[i+strideB]=(FLOAT)B1;\n"
" Pastkey_offset[i+strideB+strideB]=(FLOAT)B2;\n"
" Pastkey_offset[i+strideB+strideB+strideB]=(FLOAT)B3;\n"
" }\n"
" #else\n"
" for(int i=0; i<head_dim4; ++i){\n"
" float4 A0=convert_float4(vload4(i,A_offset));\n"
" float4 A1=A1_enable ? convert_float4(vload4(i,A_offset+strideA)) : (float4)0;\n"
" float4 A2=A2_enable ? convert_float4(vload4(i,A_offset+strideA+strideA)) : (float4)0;\n"
" float4 A3=A3_enable ? convert_float4(vload4(i,A_offset+strideA+strideA+strideA)) : (float4)0;\n"
" float4 B0=convert_float4(vload4(i,B_offset));\n"
" float4 B1=B1_enable ? convert_float4(vload4(i,B_offset+strideB)) : (float4)0;\n"
" float4 B2=B2_enable ? convert_float4(vload4(i,B_offset+strideB+strideB)) : (float4)0;\n"
" float4 B3=B3_enable ? convert_float4(vload4(i,B_offset+strideB+strideB+strideB)) : (float4)0;\n"
" \n"
" out0.x += dot(A0,B0);\n"
" out0.y += dot(A0,B1);\n"
" out0.z += dot(A0,B2);\n"
" out0.w += dot(A0,B3);\n"
" \n"
" out1.x += dot(A1,B0);\n"
" out1.y += dot(A1,B1);\n"
" out1.z += dot(A1,B2);\n"
" out1.w += dot(A1,B3);\n"
" \n"
" out2.x += dot(A2,B0);\n"
" out2.y += dot(A2,B1);\n"
" out2.z += dot(A2,B2);\n"
" out2.w += dot(A2,B3);\n"
" \n"
" out3.x += dot(A3,B0);\n"
" out3.y += dot(A3,B1);\n"
" out3.z += dot(A3,B2);\n"
" out3.w += dot(A3,B3);\n"
" \n"
" vstore4(CONVERT_FLOAT4(B0),i,Pastkey_offset);\n"
" vstore4(CONVERT_FLOAT4(B1),i,Pastkey_offset+strideB);\n"
" vstore4(CONVERT_FLOAT4(B2),i,Pastkey_offset+strideB+strideB);\n"
" vstore4(CONVERT_FLOAT4(B3),i,Pastkey_offset+strideB+strideB+strideB);\n"
" }\n"
" #endif\n"
" out0 *= (float4)scale;\n"
" out1 *= (float4)scale;\n"
" out2 *= (float4)scale;\n"
" out3 *= (float4)scale;\n"
// Rows past query_seq_len are never stored, so their mask loads are skipped.
// Otherwise the last row block would read beyond the end of the mask.
// NOTE(review): each load still covers four columns, so for the last key block
// (x4+3 >= key_seq_len) the final valid row can read up to 3 elements past the
// mask unless the host pads it.
" float4 mask0=convert_float4(vload4(0,mask+y4*key_seq_len+x4));\n"
" float4 mask1=A1_enable ? convert_float4(vload4(0,mask+(y4+1)*key_seq_len+x4)) : (float4)0;\n"
" float4 mask2=A2_enable ? convert_float4(vload4(0,mask+(y4+2)*key_seq_len+x4)) : (float4)0;\n"
" float4 mask3=A3_enable ? convert_float4(vload4(0,mask+(y4+3)*key_seq_len+x4)) : (float4)0;\n"
" #ifdef ADD_MASK\n"
" out0 += mask0;\n"
" out1 += mask1;\n"
" out2 += mask2;\n"
" out3 += mask3;\n"
" #else\n"
" out0=(mask0 == (float4)0) ? (float4)(-FLT_MAX) : out0;\n"
" out1=(mask1 == (float4)0) ? (float4)(-FLT_MAX) : out1;\n"
" out2=(mask2 == (float4)0) ? (float4)(-FLT_MAX) : out2;\n"
" out3=(mask3 == (float4)0) ? (float4)(-FLT_MAX) : out3;\n"
" #endif\n"
// Store width follows the number of valid key columns. The early returns stop
// at the first row past query_seq_len.
" if(B3_enable){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n"
" if(!A1_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset);\n"
" if(!A2_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset);\n"
" if(!A3_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset);\n"
" } else if(B2_enable){\n"
" vstore3(CONVERT_FLOAT3((float3)(out0.x,out0.y,out0.z)),0,output+output_offset);\n"
" if(!A1_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore3(CONVERT_FLOAT3((float3)(out1.x,out1.y,out1.z)),0,output+output_offset);\n"
" if(!A2_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore3(CONVERT_FLOAT3((float3)(out2.x,out2.y,out2.z)),0,output+output_offset);\n"
" if(!A3_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore3(CONVERT_FLOAT3((float3)(out3.x,out3.y,out3.z)),0,output+output_offset);\n"
" } else if(B1_enable){\n"
" vstore2(CONVERT_FLOAT2((float2)(out0.x,out0.y)),0,output+output_offset);\n"
" if(!A1_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore2(CONVERT_FLOAT2((float2)(out1.x,out1.y)),0,output+output_offset);\n"
" if(!A2_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore2(CONVERT_FLOAT2((float2)(out2.x,out2.y)),0,output+output_offset);\n"
" if(!A3_enable) return;\n"
" output_offset += key_seq_len;\n"
" vstore2(CONVERT_FLOAT2((float2)(out3.x,out3.y)),0,output+output_offset);\n"
" } else {\n"
" output[output_offset]=out0.x;\n"
" if(!A1_enable) return;\n"
" output[output_offset+key_seq_len]=out1.x;\n"
" if(!A2_enable) return;\n"
" output[output_offset+key_seq_len+key_seq_len]=out2.x;\n"
" if(!A3_enable) return;\n"
" output[output_offset+key_seq_len+key_seq_len+key_seq_len]=out3.x;\n"
" }\n"
"#else\n"
" float4 out=0;\n"
" const int head_dim4=(head_dim+3)/4;\n"
" int key_seq_len4=(key_seq_len+3)/4;\n"
" #ifdef HEADDIM_LEAVE\n"
" for(int i=0; i<head_dim4-1; ++i){\n"
" float4 A=convert_float4(vload4(i,A_offset));\n"
" float4 B0=convert_float4(vload4(i,Pastkey_offset));\n"
" float4 B1=convert_float4(vload4(i,Pastkey_offset+strideB));\n"
" float4 B2=convert_float4(vload4(i,Pastkey_offset+strideB+strideB));\n"
" float4 B3=convert_float4(vload4(i,Pastkey_offset+strideB+strideB+strideB));\n"
" \n"
" out.x += dot(A,B0);\n"
" out.y += dot(A,B1);\n"
" out.z += dot(A,B2);\n"
" out.w += dot(A,B3);\n"
" }\n"
// Scalar tail: B0..B3 are key rows x4..x4+3 (0, 1, 2 and 3 strides from
// Pastkey_offset), matching the vector loop above.
" for(int i=(head_dim4-1)*4; i<head_dim; ++i){\n"
" float A=A_offset[i];\n"
" float B0=Pastkey_offset[i];\n"
" float B1=Pastkey_offset[i+strideB];\n"
" float B2=Pastkey_offset[i+strideB+strideB];\n"
" float B3=Pastkey_offset[i+strideB+strideB+strideB];\n"
" out.x += A*B0;\n"
" out.y += A*B1;\n"
" out.z += A*B2;\n"
" out.w += A*B3;\n"
" }\n"
" #else\n"
" for(int i=0; i<head_dim4; ++i){\n"
" float4 A=convert_float4(vload4(i,A_offset));\n"
" float4 B0=convert_float4(vload4(i,Pastkey_offset));\n"
" float4 B1=convert_float4(vload4(i,Pastkey_offset+strideB));\n"
" float4 B2=convert_float4(vload4(i,Pastkey_offset+strideB+strideB));\n"
" float4 B3=convert_float4(vload4(i,Pastkey_offset+strideB+strideB+strideB));\n"
" \n"
" out.x += dot(A,B0);\n"
" out.y += dot(A,B1);\n"
" out.z += dot(A,B2);\n"
" out.w += dot(A,B3);\n"
" }\n"
" #endif\n"
// The work-item holding the last key block recomputes lane remain-1 against
// the new key from input1 and appends that key at row key_seq_len-1.
" int remain=key_seq_len-x4;\n"
" if(x == key_seq_len4-1){\n"
" __global const FLOAT *B_offset=input1+zin*head_dim;\n"
" Pastkey_offset += (remain-1)*strideB;\n"
" float tmp=0;\n"
" #ifdef HEADDIM_LEAVE\n"
" for(int i=0; i<head_dim4-1; ++i){\n"
" float4 A=convert_float4(vload4(i,A_offset));\n"
" float4 B=convert_float4(vload4(i,B_offset));\n"
" \n"
" tmp += dot(A,B);\n"
" vstore4(CONVERT_FLOAT4(B),i,Pastkey_offset);\n"
" }\n"
" for(int i=(head_dim4-1)*4; i<head_dim; ++i){\n"
" float A=A_offset[i];\n"
" float B=B_offset[i];\n"
" tmp += A*B;\n"
" Pastkey_offset[i]=B;\n"
" }\n"
" #else\n"
" for(int i=0; i<head_dim4; ++i){\n"
" float4 A=convert_float4(vload4(i,A_offset));\n"
" float4 B=convert_float4(vload4(i,B_offset));\n"
" \n"
" tmp += dot(A,B);\n"
" vstore4(CONVERT_FLOAT4(B),i,Pastkey_offset);\n"
" }\n"
" #endif\n"
" float *out_ptr=(float*)&out;\n"
" out_ptr[remain-1]=tmp;\n"
" }\n"
" out *= (float4)scale;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out),0,output+z*key_seq_len+x4);\n"
" } else if (remain >= 3){\n"
" vstore3(CONVERT_FLOAT3((float3)(out.x,out.y,out.z)),0,output+z*key_seq_len+x4);\n"
" } else if (remain >= 2){\n"
" vstore2(CONVERT_FLOAT2((float2)(out.x,out.y)),0,output+z*key_seq_len+x4);\n"
" } else {\n"
" output[z*key_seq_len+x4]=out.x;\n"
" }\n"
"#endif\n"
"}\n"
// matmul_qkv: multiplies the attention weights in input0 by V for query head y,
// using V head yin = y / NUMHEAD_GROUP_SIZE (NUMHEAD_GROUP_SIZE gets its default
// earlier in this string). x selects four head-dim columns (x4 = x*4).
//  Prefill (OPENCL_PREFILL_ATTENTION defined): z selects four query rows
//  (z4 = z*4). V rows are read from input1 and also copied into past_value.
//  Decode: rows 0..value_seq_len-2 come from past_value and the last row comes
//  from input1; that last row is then written to past_value row
//  value_seq_len-1.
// Results are stored at output[(q*head_num + y)*head_dim + d].
// Partial stores use swizzles (.xy / .xyz) so the vector width always matches
// the CONVERT_FLOAT2/3 conversion and the vstore2/3 that follow.
// NOTE(review): V rows are always loaded four head-dim values wide; with
// HEADDIM_LEAVE the last x block reads (and in prefill, copies) past head_dim.
"__kernel void matmul_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input0,// qk prefill [1 head_num qk_seq_len value_seq_len] decode[1 head_num value_seq_len]\n"
" __global const FLOAT *input1,// [1 value_seq_len head_num head_dim]\n"
" __global FLOAT *output,// [1 qk_seq_len head_num head_dim]\n"
" __global FLOAT *past_value,// [1 value_seq_len head_num head_dim]\n"
" __private const int qk_seq_len,\n"
" __private const int value_seq_len,\n"
" __private const int head_num,\n"
" __private const int kv_head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // head_dim << 2\n"
" const int y=get_global_id(1); // head_num\n"
" const int z=get_global_id(2); // prefill qk_seq_len decode 1\n"
" \n"
" const int x4=x << 2;\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int yin=y/NUMHEAD_GROUP_SIZE;\n"
"#ifdef OPENCL_PREFILL_ATTENTION\n"
" int z4=z << 2;\n"
" int value_seq_len4=(value_seq_len+3)/4;\n"
" int loop_end=max(value_seq_len4-1,0);\n"
" const int stride=kv_head_num*head_dim;\n"
" __global const FLOAT *A_offset=input0+(y*qk_seq_len+z4)*value_seq_len;\n"
" __global const FLOAT *B_offset=input1+yin*head_dim+x4;\n"
" __global FLOAT *Pastvalue_offset=past_value+yin*head_dim+x4;\n"
" COMPUTE_FLOAT4 out0=0;\n"
" COMPUTE_FLOAT4 out1=0;\n"
" COMPUTE_FLOAT4 out2=0;\n"
" COMPUTE_FLOAT4 out3=0;\n"
// Weight rows past qk_seq_len are never stored. Loading them as zero keeps
// the last query block of the last head from reading beyond input0.
" bool A1_enable=z4+1<qk_seq_len;\n"
" bool A2_enable=z4+2<qk_seq_len;\n"
" bool A3_enable=z4+3<qk_seq_len;\n"
" \n"
// Groups of four V rows; the final (possibly partial) group is handled by
// the scalar loop that follows.
" for(int i=0; i<loop_end; ++i){\n"
" int index=i << 2;\n"
" COMPUTE_FLOAT4 A0=CONVERT_COMPUTE_FLOAT4(vload4(i,A_offset));\n"
" COMPUTE_FLOAT4 A1=A1_enable ? CONVERT_COMPUTE_FLOAT4(vload4(i,A_offset+value_seq_len)) : (COMPUTE_FLOAT4)0;\n"
" COMPUTE_FLOAT4 A2=A2_enable ? CONVERT_COMPUTE_FLOAT4(vload4(i,A_offset+value_seq_len+value_seq_len)) : (COMPUTE_FLOAT4)0;\n"
" COMPUTE_FLOAT4 A3=A3_enable ? CONVERT_COMPUTE_FLOAT4(vload4(i,A_offset+value_seq_len+value_seq_len+value_seq_len)) : (COMPUTE_FLOAT4)0;\n"
" COMPUTE_FLOAT4 B0=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset+(index+0)*stride));\n"
" COMPUTE_FLOAT4 B1=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset+(index+1)*stride));\n"
" COMPUTE_FLOAT4 B2=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset+(index+2)*stride));\n"
" COMPUTE_FLOAT4 B3=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset+(index+3)*stride));\n"
" \n"
" out0=mad(B0,(COMPUTE_FLOAT4)A0.x,out0);\n"
" out0=mad(B1,(COMPUTE_FLOAT4)A0.y,out0);\n"
" out0=mad(B2,(COMPUTE_FLOAT4)A0.z,out0);\n"
" out0=mad(B3,(COMPUTE_FLOAT4)A0.w,out0);\n"
" \n"
" out1=mad(B0,(COMPUTE_FLOAT4)A1.x,out1);\n"
" out1=mad(B1,(COMPUTE_FLOAT4)A1.y,out1);\n"
" out1=mad(B2,(COMPUTE_FLOAT4)A1.z,out1);\n"
" out1=mad(B3,(COMPUTE_FLOAT4)A1.w,out1);\n"
" \n"
" out2=mad(B0,(COMPUTE_FLOAT4)A2.x,out2);\n"
" out2=mad(B1,(COMPUTE_FLOAT4)A2.y,out2);\n"
" out2=mad(B2,(COMPUTE_FLOAT4)A2.z,out2);\n"
" out2=mad(B3,(COMPUTE_FLOAT4)A2.w,out2);\n"
" \n"
" out3=mad(B0,(COMPUTE_FLOAT4)A3.x,out3);\n"
" out3=mad(B1,(COMPUTE_FLOAT4)A3.y,out3);\n"
" out3=mad(B2,(COMPUTE_FLOAT4)A3.z,out3);\n"
" out3=mad(B3,(COMPUTE_FLOAT4)A3.w,out3);\n"
" vstore4(CONVERT_FLOAT4(B0),0,Pastvalue_offset+(index+0)*stride);\n"
" vstore4(CONVERT_FLOAT4(B1),0,Pastvalue_offset+(index+1)*stride);\n"
" vstore4(CONVERT_FLOAT4(B2),0,Pastvalue_offset+(index+2)*stride);\n"
" vstore4(CONVERT_FLOAT4(B3),0,Pastvalue_offset+(index+3)*stride);\n"
" }\n"
" for(int i=loop_end << 2; i<value_seq_len; ++i){\n"
" COMPUTE_FLOAT A0=A_offset[i];\n"
" COMPUTE_FLOAT A1=A1_enable ? (COMPUTE_FLOAT)A_offset[i+value_seq_len] : (COMPUTE_FLOAT)0;\n"
" COMPUTE_FLOAT A2=A2_enable ? (COMPUTE_FLOAT)A_offset[i+value_seq_len+value_seq_len] : (COMPUTE_FLOAT)0;\n"
" COMPUTE_FLOAT A3=A3_enable ? (COMPUTE_FLOAT)A_offset[i+value_seq_len+value_seq_len+value_seq_len] : (COMPUTE_FLOAT)0;\n"
" COMPUTE_FLOAT4 B=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset+i*stride));\n"
" \n"
" out0=mad(B,(COMPUTE_FLOAT4)A0,out0);\n"
" out1=mad(B,(COMPUTE_FLOAT4)A1,out1);\n"
" out2=mad(B,(COMPUTE_FLOAT4)A2,out2);\n"
" out3=mad(B,(COMPUTE_FLOAT4)A3,out3);\n"
" vstore4(CONVERT_FLOAT4(B),0,Pastvalue_offset+i*stride);\n"
" }\n"
" \n"
// Each of the four query rows is stored at output_offset, which advances one
// query row (head_num*head_dim) at a time. remain limits the width of the last
// head-dim block.
" #ifdef HEADDIM_LEAVE\n"
" int remain=head_dim-x4;\n"
" int output_offset=(z4*head_num+y)*head_dim+x4;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n"
" } else if(remain == 3){\n"
" vstore3(CONVERT_FLOAT3(out0.xyz),0,output+output_offset);\n"
" } else if(remain == 2){\n"
" vstore2(CONVERT_FLOAT2(out0.xy),0,output+output_offset);\n"
" } else{\n"
" output[output_offset]=out0.x;\n"
" }\n"
" if(z4+1 >= qk_seq_len) return;\n"
" output_offset += head_num*head_dim;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset);\n"
" } else if(remain == 3){\n"
" vstore3(CONVERT_FLOAT3(out1.xyz),0,output+output_offset);\n"
" } else if(remain == 2){\n"
" vstore2(CONVERT_FLOAT2(out1.xy),0,output+output_offset);\n"
" } else{\n"
" output[output_offset]=out1.x;\n"
" }\n"
" if(z4+2 >= qk_seq_len) return;\n"
" output_offset += head_num*head_dim;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset);\n"
" } else if(remain == 3){\n"
" vstore3(CONVERT_FLOAT3(out2.xyz),0,output+output_offset);\n"
" } else if(remain == 2){\n"
" vstore2(CONVERT_FLOAT2(out2.xy),0,output+output_offset);\n"
" } else{\n"
" output[output_offset]=out2.x;\n"
" }\n"
" if(z4+3 >= qk_seq_len) return;\n"
" output_offset += head_num*head_dim;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset);\n"
" } else if(remain == 3){\n"
" vstore3(CONVERT_FLOAT3(out3.xyz),0,output+output_offset);\n"
" } else if(remain == 2){\n"
" vstore2(CONVERT_FLOAT2(out3.xy),0,output+output_offset);\n"
" } else{\n"
" output[output_offset]=out3.x;\n"
" }\n"
" #else\n"
" int output_offset=(z4*head_num+y)*head_dim+x4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n"
" if(z4+1 >= qk_seq_len) return;\n"
" output_offset += head_num*head_dim;\n"
" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset);\n"
" if(z4+2 >= qk_seq_len) return;\n"
" output_offset += head_num*head_dim;\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset);\n"
" if(z4+3 >= qk_seq_len) return;\n"
" output_offset += head_num*head_dim;\n"
" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset);\n"
" #endif\n"
"#else\n"
// Decode: value_seq_len4 counts groups of the value_seq_len-1 cached rows.
" int value_seq_len4=(value_seq_len-1+3)/4;\n"
" int loop_end=max(value_seq_len4-1,0);\n"
" const int stride=kv_head_num*head_dim;\n"
" __global const FLOAT *A_offset=input0+y*value_seq_len;\n"
" __global const FLOAT *B_offset=input1+yin*head_dim+x4;\n"
" __global FLOAT *Pastvalue_offset=past_value+yin*head_dim+x4;\n"
" COMPUTE_FLOAT4 out=0;\n"
" \n"
" for(int i=0; i<loop_end; i++){\n"
" int index=i << 2;\n"
" COMPUTE_FLOAT4 A=CONVERT_COMPUTE_FLOAT4(vload4(i,A_offset));\n"
" COMPUTE_FLOAT4 B0=CONVERT_COMPUTE_FLOAT4(vload4(0,Pastvalue_offset+(index+0)*stride));\n"
" COMPUTE_FLOAT4 B1=CONVERT_COMPUTE_FLOAT4(vload4(0,Pastvalue_offset+(index+1)*stride));\n"
" COMPUTE_FLOAT4 B2=CONVERT_COMPUTE_FLOAT4(vload4(0,Pastvalue_offset+(index+2)*stride));\n"
" COMPUTE_FLOAT4 B3=CONVERT_COMPUTE_FLOAT4(vload4(0,Pastvalue_offset+(index+3)*stride));\n"
" \n"
" out=mad(B0,(COMPUTE_FLOAT4)A.x,out);\n"
" out=mad(B1,(COMPUTE_FLOAT4)A.y,out);\n"
" out=mad(B2,(COMPUTE_FLOAT4)A.z,out);\n"
" out=mad(B3,(COMPUTE_FLOAT4)A.w,out);\n"
" }\n"
" for(int i=loop_end << 2; i<value_seq_len-1; i++){\n"
" COMPUTE_FLOAT A=A_offset[i];\n"
" COMPUTE_FLOAT4 B=CONVERT_COMPUTE_FLOAT4(vload4(0,Pastvalue_offset+i*stride));\n"
" \n"
" out=mad(B,(COMPUTE_FLOAT4)A,out);\n"
" }\n"
// The newest V row comes from input1 and is appended to the cache below.
" COMPUTE_FLOAT A=A_offset[value_seq_len-1];\n"
" COMPUTE_FLOAT4 B=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset));\n"
" out=mad(B,(COMPUTE_FLOAT4)A,out);\n"
" \n"
// All four branches write the single decode output row at y*head_dim+x4.
" #ifdef HEADDIM_LEAVE\n"
" int remain=head_dim-x4;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out),0,output+y*head_dim+x4);\n"
" vstore4(CONVERT_FLOAT4(B),0,Pastvalue_offset+(value_seq_len-1)*stride);\n"
" } else if(remain == 3){\n"
" vstore3(CONVERT_FLOAT3(out.xyz),0,output+y*head_dim+x4);\n"
" vstore3(CONVERT_FLOAT3(B.xyz),0,Pastvalue_offset+(value_seq_len-1)*stride);\n"
" } else if(remain == 2){\n"
" vstore2(CONVERT_FLOAT2(out.xy),0,output+y*head_dim+x4);\n"
" vstore2(CONVERT_FLOAT2(B.xy),0,Pastvalue_offset+(value_seq_len-1)*stride);\n"
" } else{\n"
" output[y*head_dim+x4]=out.x;\n"
" Pastvalue_offset[(value_seq_len-1)*stride]=B.x;\n"
" }\n"
" #else\n"
" vstore4(CONVERT_FLOAT4(B),0,Pastvalue_offset+(value_seq_len-1)*stride);\n"
" vstore4(CONVERT_FLOAT4(out),0,output+y*head_dim+x4);\n"
" #endif\n"
" \n"
"#endif\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL source for groupnorm_plain_buf (buffer backend).
// Each work-group (LOCAL_SIZE work-items along dim 0, lid = get_local_id(0))
// normalizes the `inside` contiguous elements starting at pos.z * inside:
//   1. reduce the element sum into the local array `sum` -> mean,
//   2. reduce the squared deviations into `sum` -> variance,
//   3. write (x - mean) / sqrt(variance + epsilon), optionally followed by the
//      GAMMA_BETA affine transform and SWISH.
// With DOUBLE_INPUTS, one input1 value per `area` consecutive elements is added
// to input0 before normalizing. WH_4 selects a vload4/vstore4 path that is only
// valid when the product of W and H is a multiple of 4.
// The local array `sum` is reused between the two reductions, so every
// work-item must read sum[0] (the mean) before any work-item overwrites sum[];
// a barrier after each mean read enforces that.
// NOTE(review): all barriers sit inside the global-bounds `if`; this assumes
// every work-item of a group takes the same branch — confirm on the host side.
const char* groupnorm_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void groupnorm_plain_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
"#ifdef DOUBLE_INPUTS\n"
" __global const FLOAT*input0,\n"
" __global const FLOAT*input1,\n"
"#else\n"
" __global const FLOAT*input,\n"
"#endif\n"
" __global FLOAT*output,\n"
" __private const int area,\n"
" __private const int group,\n"
" __private const int inside,\n"
" __private const int outside,\n"
"#ifdef GAMMA_BETA\n"
" __global const FLOAT *gamma,\n"
" __global const FLOAT *beta,\n"
"#endif\n"
" __private float epsilon){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" float local sum[LOCAL_SIZE];\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int idx_out=pos.z;\n"
" const int lid=get_local_id(0);\n"
" const int offset=idx_out*inside;\n"
" const int inside_v4=(inside+3) >> 2;\n"
" \n"
"#ifdef DOUBLE_INPUTS\n"
" // The product of W and H is a multiple of 4\n"
" #ifdef WH_4\n"
" float4 in_sum=0;\n"
" int index=lid;\n"
" for(; index<inside_v4; index+=LOCAL_SIZE){\n"
" float4 in0=convert_float4(vload4(index,input0+offset));\n"
" in_sum += in0;\n"
" float in1=input1[idx_out*(inside/area)+index/(area/4)];\n"
" in_sum += (float4)(in1,in1,in1,in1);\n"
" }\n"
" sum[lid]=in_sum.x+in_sum.y+in_sum.z+ in_sum.w;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=sum[0]/(float4)inside;\n"
// All work-items must finish reading sum[0] before sum[] is reused below.
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" in_sum=0;\n"
" index=lid;\n"
" for(; index<inside_v4; index+=LOCAL_SIZE){\n"
" float4 in0=convert_float4(vload4(index,input0+offset));\n"
" float in1=input1[idx_out*(inside/area)+index/(area/4)];\n"
" in_sum += (in0+(float4)(in1,in1,in1,in1)-mean)*(in0+(float4)in1-mean);\n"
" }\n"
" sum[lid]=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" \n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float4 square_sum=(float4)(sum[0]/inside);\n"
" float4 value=(float4)(1.0f/sqrt(square_sum.x+epsilon));\n"
" for(int i=lid; i<inside_v4; i+=LOCAL_SIZE){\n"
" float4 in0=convert_float4(vload4(i,input0+offset));\n"
" float in1=input1[idx_out*(inside/area)+i/(area/4)];\n"
" float4 out=(in0+(float4)(in1,in1,in1,in1)-mean)*value;\n"
" #ifdef GAMMA_BETA\n"
" int offset_gamma_beta=(idx_out % group)*inside/area+i/(area/4);\n"
" out=out*(float4)((float)gamma[offset_gamma_beta])+(float4)((float)beta[offset_gamma_beta]);\n"
" #endif\n"
" #ifdef SWISH\n"
" out=out*native_recip((float4)1+native_exp(convert_float4(-out)));\n"
" #endif\n"
" vstore4(CONVERT_FLOAT4(out),i,output+offset);\n"
" }\n"
" #else\n"
" \n"
" float in_sum=0;\n"
" int index=lid;\n"
" for(; index<inside; index+=LOCAL_SIZE){\n"
" float in0=input0[offset+index];\n"
" in_sum += in0;\n"
" float in1=input1[idx_out*(inside/area)+index/area];\n"
" in_sum += in1;\n"
" }\n"
" sum[lid]=in_sum;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float mean=sum[0]/inside;\n"
// All work-items must finish reading sum[0] before sum[] is reused below.
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" in_sum=0;\n"
" index=lid;\n"
" for(; index<inside; index+=LOCAL_SIZE){\n"
" float in0=input0[offset+index];\n"
" float in1=input1[idx_out*(inside/area)+index/area];\n"
" in_sum += (in0+in1-mean)*(in0+in1-mean);\n"
" }\n"
" sum[lid]=in_sum;\n"
" \n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float square_sum=sum[0]/inside;\n"
" float value=1.0f/sqrt(square_sum+epsilon);\n"
" for(int i=lid; i<inside; i+=LOCAL_SIZE){\n"
" float in0=input0[offset+i];\n"
" float in1=input1[idx_out*(inside/area)+i/area];\n"
" float out=(in0+in1-mean)*value;\n"
" #ifdef GAMMA_BETA\n"
" int offset_gamma_beta=(idx_out % group)*inside/area+i/area;\n"
" out=out*(float)gamma[offset_gamma_beta]+(float)beta[offset_gamma_beta];\n"
" #endif\n"
" #ifdef SWISH\n"
// Single-precision literal: native_recip/native_exp only have float overloads.
" out=out*native_recip(1.0f+native_exp(-out));\n"
" #endif\n"
" output[offset+i]=(FLOAT)out;\n"
" }\n"
" \n"
" #endif\n"
"#else\n"
" const int inside_remain=inside-((inside_v4-1) << 2);\n"
" float4 in_sum=0;\n"
" int index=lid;\n"
" for(; index<inside_v4-1; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" in_sum += in;\n"
" }\n"
" sum[lid]=in_sum.x+in_sum.y+in_sum.z+ in_sum.w;\n"
" \n"
" float4 in_left=0;\n"
" if(index == inside_v4-1) {\n"
" in_left=convert_float4(vload4(inside_v4-1,input+offset));\n"
" sum[lid]=sum[lid]+in_left.x;\n"
" if(inside_remain>1) {\n"
" sum[lid]=sum[lid]+in_left.y;\n"
" }\n"
" if(inside_remain>2) {\n"
" sum[lid]=sum[lid]+in_left.z;\n"
" }\n"
" if(inside_remain>3) {\n"
" sum[lid]=sum[lid]+in_left.w;\n"
" }\n"
" }\n"
" \n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=(float4)(sum[0]/inside);\n"
// All work-items must finish reading sum[0] before sum[] is reused below.
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" in_sum=0;\n"
" index=lid;\n"
" for(; index<inside_v4-1; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" sum[lid]=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" \n"
" if(index == inside_v4-1) {\n"
" float4 in_left=convert_float4(vload4(inside_v4-1,input+offset));\n"
" in_sum=(in_left-mean)*(in_left-mean);\n"
" sum[lid]=sum[lid]+in_sum.x;\n"
" if(inside_remain>1) {\n"
" sum[lid]=sum[lid]+in_sum.y;\n"
" }\n"
" if(inside_remain>2) {\n"
" sum[lid]=sum[lid]+in_sum.z;\n"
" }\n"
" if(inside_remain>3) {\n"
" sum[lid]=sum[lid]+in_sum.w;\n"
" }\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float4 square_sum=(float4)(sum[0]/inside);\n"
" float4 value=(float4)(1.0f/sqrt(square_sum.x+epsilon));\n"
" // The product of W and H is a multiple of 4\n"
" #ifdef WH_4\n"
" for(int i=lid; i<inside_v4; i+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(i,input+offset));\n"
" float4 out=(in-mean)*value;\n"
" #ifdef GAMMA_BETA\n"
" int offset_gamma_beta=(idx_out % group)*inside/area+i/(area/4);\n"
" out=out*(float4)((float)gamma[offset_gamma_beta])+(float4)((float)beta[offset_gamma_beta]);\n"
" #endif\n"
" #ifdef SWISH\n"
" out=out*native_recip((float4)1+native_exp(convert_float4(-out)));\n"
" #endif\n"
" vstore4(CONVERT_FLOAT4(out),i,output+offset);\n"
" }\n"
" #else\n"
" for(int i=lid; i<inside; i+=LOCAL_SIZE){\n"
" float in=input[offset+i];\n"
" float out=(in-mean.x)*value.x;\n"
" #ifdef GAMMA_BETA\n"
" int offset_gamma_beta=(idx_out % group)*inside/area+i/area;\n"
" out=out*(float)gamma[offset_gamma_beta]+(float)beta[offset_gamma_beta];\n"
" #endif\n"
" #ifdef SWISH\n"
// Single-precision literal: native_recip/native_exp only have float overloads.
" out=out*native_recip(1.0f+native_exp(-out));\n"
" #endif\n"
" \n"
" output[offset+i]=(FLOAT)out;\n"
" }\n"
" #endif\n"
"#endif\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
// OpenCL source for element-wise unary ops on the buffer backend, including
// layout conversion between two packings (kernel suffix = input_output packing):
//   c4  : element offset (((b + cblock4*batch)*height + h)*width + w)*4,
//         i.e. 4-channel blocks outermost with the batch nested inside them;
//   c16 : element offset (((b*channel16 + cblock16)*height + h)*row_width
//         + w + pad_left)*16, with row_width = pad_left + width + pad_right.
// OPERATOR is not defined in this text; presumably the host supplies it as a
// build option (an expression in `in`, possibly calling gelu()) — confirm.
// The c16-input kernels require an Intel sub-group size of 16.
const char* unary_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// gelu: 0.5*x*(1 + t), where t approximates tanh(0.79788458*(x + 0.044715*x^3))
// with a rational polynomial, clamped to +/-1 once |argument| exceeds 5.
"inline float4 gelu(float4 in){\n"
" float4 value=0.79788458f*(0.044715f*in*in*in+in);\n"
" float4 x2=value*value;\n"
" float4 dst=value>(float4)5.0f ? (float4)1.0f : (value <= -(float4)5.0f ? -(float4)1.0f :\n"
" (value*(135135.0f+x2*(17325.0f+x2*(378.0f+x2))))/(135135.0f+x2*(62370.0f+x2*(3150.0f+x2*28.0f))));\n"
" return (1.0f+dst)*in*0.5f;\n"
"}\n"
// unary_buf_c4_c4: one work-item per (channel block, w, batch*height row);
// applies OPERATOR to one 4-channel texel in place of layout c4.
"__kernel void unary_buf_c4_c4(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel_block_idx=get_global_id(0);\n"
" const int w=get_global_id(1);\n"
" const int hb=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);\n"
" const int batch_idx=hb/height;\n"
" const int height_idx=hb % height;\n"
" const int offset=(((batch_idx+channel_block_idx*batch)*height+height_idx)*width+w)*4;\n"
" float4 in=convert_float4(vload4(0,input+offset));\n"
" float4 out=OPERATOR;\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+offset);\n"
"}\n"
// unary_buf_c4_c16: reads one c4 texel and stores it into lanes
// (channel_block_idx % 4)*4.. of the matching 16-channel block; the w == 0
// work-item also zeroes the left/right padding columns of that row.
"__kernel void unary_buf_c4_c16(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel_block_idx=get_global_id(0);\n"
" const int w=get_global_id(1);\n"
" const int hb=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);\n"
" const int batch_idx=hb/height;\n"
" const int height_idx=hb % height;\n"
" const int dst_width=output_pad_left+width+output_pad_right;\n"
" const int channel16=(channel+15)/16;\n"
" const int channe_out_idx=channel_block_idx >> 2;\n"
" const int offset=(((batch_idx+channel_block_idx*batch)*height+height_idx)*width+w)*4;\n"
" const int dst_offset=(((batch_idx*channel16+channe_out_idx)*height+height_idx)*dst_width+w+output_pad_left)*16+(channel_block_idx % 4)*4;\n"
" float4 in=convert_float4(vload4(0,input+offset));\n"
" float4 out=OPERATOR;\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+dst_offset);\n"
" if(w == 0){\n"
" int pad_offset=(((batch_idx*channel16+channe_out_idx)*height+height_idx)*dst_width)*16+(channel_block_idx % 4)*4;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" pad_offset += (width+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" vstore4((OUTPUT_TYPE4)0,0,output+pad_offset+i*16);\n"
" }\n"
" }\n"
"}\n"
// unary_buf_c16_c16: each sub-group reads 4 consecutive x positions of one
// 16-channel block (one channel per lane) and writes them back in c16 layout,
// storing only the row tail when fewer than 4 columns remain. The w == 0
// sub-group also zeroes the output padding columns of its row.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void unary_buf_c16_c16(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel_idx=get_group_id(0);\n"
" const int w=get_global_id(1) << 2;\n"
" const int hb=get_global_id(2);\n"
" const int sglid=get_sub_group_local_id();\n"
" const int batch_idx=hb/height;\n"
" const int height_idx=hb % height;\n"
" const int src_width=width+input_pad_left+input_pad_right;\n"
" const int dst_width=width+output_pad_left+output_pad_right;\n"
" const int channel16=(channel+15)/16;\n"
" const int src_offset=(((batch_idx*channel16+channel_idx)*height+height_idx)*src_width+w+input_pad_left)*16;\n"
" const int dst_offset=(((batch_idx*channel16+channel_idx)*height+height_idx)*dst_width+w+output_pad_left)*16;\n"
" \n"
" float4 in=convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input+src_offset))));\n"
" float4 out=OPERATOR;\n"
" if (w+4>width) {\n"
" for (int i=0; i<width % 4; i++) {\n"
" output[dst_offset+i*16+sglid]=(OUTPUT_TYPE)out[i];\n"
" }\n"
" } else{\n"
" INTEL_SUB_GROUP_WRITE4((__global INTEL_DATA*)(output+dst_offset),AS_OUTPUT_DATA4(CONVERT_OUTPUT4(out)));\n"
" }\n"
" if(w == 0){\n"
// The row base must count 16-channel blocks (channel16), exactly like
// dst_offset; using the raw channel count misplaces the padding for batch_idx > 0.
" int pad_offset=(((batch_idx*channel16+channel_idx)*height+height_idx)*dst_width)*16+sglid;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*16]=(OUTPUT_TYPE)0;\n"
" }\n"
" pad_offset += (width+output_pad_left)*16;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*16]=(OUTPUT_TYPE)0;\n"
" }\n"
" }\n"
"}\n"
// unary_buf_c16_c4: each sub-group reads 4 x positions of one 16-channel block
// and scatters them to the 4 c4 channel blocks it covers (lane sglid ->
// block (channel_idx*4 + sglid/4), channel sglid%4). In the c4 layout, adjacent
// channel blocks are batch*height*width*4 elements apart, and blocks at or past
// ceil(channel/4) do not exist, so those lanes skip the store.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void unary_buf_c16_c4(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" const int channel_idx=get_group_id(0);\n"
" const int w=get_global_id(1) << 2;\n"
" const int hb=get_global_id(2);\n"
" const int sglid=get_sub_group_local_id();\n"
" const int batch_idx=hb/height;\n"
" const int height_idx=hb % height;\n"
" const int src_width=width+input_pad_left+input_pad_right;\n"
" const int channel16=(channel+15)/16;\n"
" const int channel4=(channel+3)/4;\n"
" const int src_offset=(((batch_idx*channel16+channel_idx)*height+height_idx)*src_width+w+input_pad_left)*16;\n"
" const int dst_offset=(((batch_idx+(channel_idx<<2)*batch)*height+height_idx)*width+w)*4;\n"
" const int channel_block_stride=batch*height*width*4;\n"
" \n"
" float4 in=convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input+src_offset))));\n"
" float4 out=OPERATOR;\n"
" const int lid_x=sglid % 4;\n"
" const int lid_y=sglid/4;\n"
" int block_size=w+4>width ? (width % 4) : 4;\n"
" if((channel_idx << 2)+lid_y<channel4){\n"
" for (int i=0; i<block_size; i++) {\n"
" output[dst_offset+i*4+lid_y*channel_block_stride+lid_x]=(OUTPUT_TYPE)out[i];\n"
" }\n"
" }\n"
"}\n"
;
#endif
#endif
// OpenCL source for the image-based GEMM kernels (gemm, gemmWinograd,
// gemmWinogradW2) and for gemm_conv / gemm_conv_b2, which multiply an image of
// 4-channel input texels by a buffer of float, int8 or int4 weights.
// RI_F / WI_F (image read/write), FLOAT*, COMPUTE_FLOAT* and the CONVERT_*
// macros are not defined in this text; presumably the host passes them as build
// options — confirm. GLOBAL_SIZE_DIM2 / UNIFORM_BOUNDRY_CHECK are used only by
// gemm_conv and gemm_conv_b2.
const char* gemm = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// gemm: work-item (pos.x, pos.y) with pos.x < width*height and pos.y < alpha2.
// For each of multiLength steps it reads texels 4k..4k+3 from uKernel row
// pos.y*height + pos.x/width and from uInput row pos.y*width + pos.x%width,
// accumulating four FLOAT4 results o0..o3. These are written to column srcY,
// rows 4*(pos.x/width) + 0..3 of uOutput.
"__kernel void gemm(__read_only image2d_t uInput,__read_only image2d_t uKernel,__write_only image2d_t uOutput,\n"
" __private const int width,__private const int height,__private const int multiLength,__private const int alpha2) {\n"
" \n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" if (pos.x<width*height && pos.y<alpha2) {\n"
" \n"
" const int pos_x=pos.x % width;\n"
" const int pos_y=pos.x/width;\n"
" const int pos_z=pos.y;\n"
" FLOAT4 o0=(FLOAT4)(0);\n"
" FLOAT4 o1=(FLOAT4)(0);\n"
" FLOAT4 o2=(FLOAT4)(0);\n"
" FLOAT4 o3=(FLOAT4)(0);\n"
" int kenerlY=mad24(pos_z,height,pos_y);\n"
" int srcY=mad24(pos_z,width,pos_x);\n"
" for (int k=0; k<multiLength; ++k) {\n"
" __private int index=mul24(k,4);\n"
" FLOAT4 k0=RI_F(uKernel,SAMPLER,(int2)(index,kenerlY));\n"
" FLOAT4 k1=RI_F(uKernel,SAMPLER,(int2)(index+1,kenerlY));\n"
" FLOAT4 k2=RI_F(uKernel,SAMPLER,(int2)(index+2,kenerlY));\n"
" FLOAT4 k3=RI_F(uKernel,SAMPLER,(int2)(index+3,kenerlY));\n"
" FLOAT4 s0=RI_F(uInput,SAMPLER,(int2)(index,srcY));\n"
" FLOAT4 s1=RI_F(uInput,SAMPLER,(int2)(index+1,srcY));\n"
" FLOAT4 s2=RI_F(uInput,SAMPLER,(int2)(index+2,srcY));\n"
" FLOAT4 s3=RI_F(uInput,SAMPLER,(int2)(index+3,srcY));\n"
" o0=mad(s0.x,k0,o0);\n"
" o0=mad(s0.y,k1,o0);\n"
" o0=mad(s0.z,k2,o0);\n"
" o0=mad(s0.w,k3,o0);\n"
" o1=mad(s1.x,k0,o1);\n"
" o1=mad(s1.y,k1,o1);\n"
" o1=mad(s1.z,k2,o1);\n"
" o1=mad(s1.w,k3,o1);\n"
" o2=mad(s2.x,k0,o2);\n"
" o2=mad(s2.y,k1,o2);\n"
" o2=mad(s2.z,k2,o2);\n"
" o2=mad(s2.w,k3,o2);\n"
" o3=mad(s3.x,k0,o3);\n"
" o3=mad(s3.y,k1,o3);\n"
" o3=mad(s3.z,k2,o3);\n"
" o3=mad(s3.w,k3,o3);\n"
" }\n"
" __private int out_y_idx=mul24(pos_y,4);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx+1),o1);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx+2),o2);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx+3),o3);\n"
" }\n"
"}\n"
// gemmWinograd: each work-item produces 4 adjacent output columns.
// pos.x < ceil(unitWidth/4)*unitHeight; pos.y < alpha2*dstChannelC4, split into
// pos_z = pos.y % dstChannelC4 and pos_w = pos.y / dstChannelC4. Each of the
// multiLength steps reads 4 uKernel texels from row pos.y and 4 uInput texels at
// x = k*unitWidth + srcX..srcX+3, row pos_w*unitHeight + pos_y. Results go to
// x = pos_w*unitWidth + srcX, y = pos_z*unitHeight + pos_y; only
// min(4, unitWidth - srcX) columns are written so the row tail is not overrun.
"__kernel void gemmWinograd(__read_only image2d_t uInput,__read_only image2d_t uKernel,__write_only image2d_t uOutput,\n"
" __private const int unitWidth,__private const int unitHeight,__private const int dstChannelC4,__private const int multiLength,__private const int alpha2) {\n"
" \n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" const int unitWidth4=(unitWidth+3)/4;\n"
" if (pos.x<unitWidth4*unitHeight && pos.y<alpha2*dstChannelC4) {\n"
" \n"
" const int pos_x=pos.x % unitWidth4;\n"
" const int pos_y=pos.x/unitWidth4;\n"
" const int pos_z=pos.y % dstChannelC4;\n"
" const int pos_w=pos.y/dstChannelC4;\n"
" FLOAT4 o0=(FLOAT4)(0);\n"
" FLOAT4 o1=(FLOAT4)(0);\n"
" FLOAT4 o2=(FLOAT4)(0);\n"
" FLOAT4 o3=(FLOAT4)(0);\n"
" int srcY=mad24(pos_w,unitHeight,pos_y);\n"
" int srcX=pos_x << 2;\n"
" for (int k=0; k<multiLength; ++k) {\n"
" __private int index=mul24(k,4);\n"
" __private int x_offset=mul24(k,unitWidth);\n"
" FLOAT4 k0=RI_F(uKernel,SAMPLER,(int2)(index,pos.y));\n"
" FLOAT4 k1=RI_F(uKernel,SAMPLER,(int2)(index+1,pos.y));\n"
" FLOAT4 k2=RI_F(uKernel,SAMPLER,(int2)(index+2,pos.y));\n"
" FLOAT4 k3=RI_F(uKernel,SAMPLER,(int2)(index+3,pos.y));\n"
" FLOAT4 s0=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset,srcY));\n"
" FLOAT4 s1=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+1,srcY));\n"
" FLOAT4 s2=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+2,srcY));\n"
" FLOAT4 s3=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+3,srcY));\n"
" o0=mad(s0.x,k0,o0);\n"
" o0=mad(s0.y,k1,o0);\n"
" o0=mad(s0.z,k2,o0);\n"
" o0=mad(s0.w,k3,o0);\n"
" o1=mad(s1.x,k0,o1);\n"
" o1=mad(s1.y,k1,o1);\n"
" o1=mad(s1.z,k2,o1);\n"
" o1=mad(s1.w,k3,o1);\n"
" o2=mad(s2.x,k0,o2);\n"
" o2=mad(s2.y,k1,o2);\n"
" o2=mad(s2.z,k2,o2);\n"
" o2=mad(s2.w,k3,o2);\n"
" o3=mad(s3.x,k0,o3);\n"
" o3=mad(s3.y,k1,o3);\n"
" o3=mad(s3.z,k2,o3);\n"
" o3=mad(s3.w,k3,o3);\n"
" }\n"
" __private int out_y_idx=mad24(pos_z,unitHeight,pos_y);\n"
" __private int out_x_idx=mad24(pos_w,unitWidth,srcX);\n"
" const int remain=unitWidth-srcX;\n"
" if(remain >= 4){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" }else if(remain == 3){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" }else if(remain == 2){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" }else if(remain == 1){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" }\n"
" }\n"
"}\n"
// gemmWinogradW2: same as gemmWinograd, but each work-item covers 8 adjacent
// columns (srcX = 8*pos_x, pos.x < ceil(unitWidth/8)*unitHeight) and writes
// min(8, unitWidth - srcX) of them.
"__kernel void gemmWinogradW2(__read_only image2d_t uInput,__read_only image2d_t uKernel,__write_only image2d_t uOutput,\n"
" __private const int unitWidth,__private const int unitHeight,__private const int dstChannelC4,__private const int multiLength,__private const int alpha2) {\n"
" \n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" const int unitWidth8=(unitWidth+7)/8;\n"
" if (pos.x<unitWidth8*unitHeight && pos.y<alpha2*dstChannelC4) {\n"
" \n"
" const int pos_x=pos.x % unitWidth8;\n"
" const int pos_y=pos.x/unitWidth8;\n"
" const int pos_z=pos.y % dstChannelC4;\n"
" const int pos_w=pos.y/dstChannelC4;\n"
" FLOAT4 o0=(FLOAT4)(0);\n"
" FLOAT4 o1=(FLOAT4)(0);\n"
" FLOAT4 o2=(FLOAT4)(0);\n"
" FLOAT4 o3=(FLOAT4)(0);\n"
" FLOAT4 o4=(FLOAT4)(0);\n"
" FLOAT4 o5=(FLOAT4)(0);\n"
" FLOAT4 o6=(FLOAT4)(0);\n"
" FLOAT4 o7=(FLOAT4)(0);\n"
" int srcY=mad24(pos_w,unitHeight,pos_y);\n"
" int srcX=pos_x << 3;\n"
" for (int k=0; k<multiLength; ++k) {\n"
" __private int index=mul24(k,4);\n"
" __private int x_offset=mul24(k,unitWidth);\n"
" FLOAT4 k0=RI_F(uKernel,SAMPLER,(int2)(index,pos.y));\n"
" FLOAT4 k1=RI_F(uKernel,SAMPLER,(int2)(index+1,pos.y));\n"
" FLOAT4 k2=RI_F(uKernel,SAMPLER,(int2)(index+2,pos.y));\n"
" FLOAT4 k3=RI_F(uKernel,SAMPLER,(int2)(index+3,pos.y));\n"
" FLOAT4 s0=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset,srcY));\n"
" FLOAT4 s1=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+1,srcY));\n"
" FLOAT4 s2=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+2,srcY));\n"
" FLOAT4 s3=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+3,srcY));\n"
" FLOAT4 s4=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+4,srcY));\n"
" FLOAT4 s5=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+5,srcY));\n"
" FLOAT4 s6=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+6,srcY));\n"
" FLOAT4 s7=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+7,srcY));\n"
" o0=mad(s0.x,k0,o0);\n"
" o0=mad(s0.y,k1,o0);\n"
" o0=mad(s0.z,k2,o0);\n"
" o0=mad(s0.w,k3,o0);\n"
" o1=mad(s1.x,k0,o1);\n"
" o1=mad(s1.y,k1,o1);\n"
" o1=mad(s1.z,k2,o1);\n"
" o1=mad(s1.w,k3,o1);\n"
" o2=mad(s2.x,k0,o2);\n"
" o2=mad(s2.y,k1,o2);\n"
" o2=mad(s2.z,k2,o2);\n"
" o2=mad(s2.w,k3,o2);\n"
" o3=mad(s3.x,k0,o3);\n"
" o3=mad(s3.y,k1,o3);\n"
" o3=mad(s3.z,k2,o3);\n"
" o3=mad(s3.w,k3,o3);\n"
" \n"
" o4=mad(s4.x,k0,o4);\n"
" o4=mad(s4.y,k1,o4);\n"
" o4=mad(s4.z,k2,o4);\n"
" o4=mad(s4.w,k3,o4);\n"
" o5=mad(s5.x,k0,o5);\n"
" o5=mad(s5.y,k1,o5);\n"
" o5=mad(s5.z,k2,o5);\n"
" o5=mad(s5.w,k3,o5);\n"
" o6=mad(s6.x,k0,o6);\n"
" o6=mad(s6.y,k1,o6);\n"
" o6=mad(s6.z,k2,o6);\n"
" o6=mad(s6.w,k3,o6);\n"
" o7=mad(s7.x,k0,o7);\n"
" o7=mad(s7.y,k1,o7);\n"
" o7=mad(s7.z,k2,o7);\n"
" o7=mad(s7.w,k3,o7);\n"
" }\n"
" __private int out_y_idx=mad24(pos_z,unitHeight,pos_y);\n"
" __private int out_x_idx=mad24(pos_w,unitWidth,srcX);\n"
" const int remain=unitWidth-srcX;\n"
" if(remain >= 8){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" WI_F(uOutput,(int2)(out_x_idx+5,out_y_idx),o5);\n"
" WI_F(uOutput,(int2)(out_x_idx+6,out_y_idx),o6);\n"
" WI_F(uOutput,(int2)(out_x_idx+7,out_y_idx),o7);\n"
" }else if(remain == 7){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" WI_F(uOutput,(int2)(out_x_idx+5,out_y_idx),o5);\n"
" WI_F(uOutput,(int2)(out_x_idx+6,out_y_idx),o6);\n"
" }else if(remain == 6){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" WI_F(uOutput,(int2)(out_x_idx+5,out_y_idx),o5);\n"
" }else if(remain == 5){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" }else if(remain == 4){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" }else if(remain == 3){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" }else if(remain == 2){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" }else if(remain == 1){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" }\n"
" }\n"
"}\n"
// PADZEROSVEC (only with INPUT_CHANNEL_LEAVE): for input-channel block k, zeroes
// each 4-wide weight group whose input channel 4k+j is >= `channel`.
// NOTE(review): gemm_conv / gemm_conv_b2 pass `srcChannel`, which is only a
// kernel parameter when USE_LOW_BIT_WEIGHT_INT8/INT4 is defined; presumably the
// host never sets INPUT_CHANNEL_LEAVE for float weights — confirm.
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
// gemm_conv: one FLOAT4 output per work-item (pos.x = output channel block,
// pos.y = row, annotated "cout/4,b" in the kernel). It starts from bias texel
// (pos.x, 0). For each of srcChannelC4 input blocks it reads input texel
// (k, pos.y) and a 4x4 weight tile at weight_offset + k*weight_oc_offset:
// weights.s0123 scales in.x, s4567 in.y, s89ab in.z, scdef in.w.
// Int8 tiles are 16 chars. Int4 tiles are 8 bytes, high nibble first, and each
// nibble has 8 subtracted. Quantized weights become w*scale + offset, using the
// interleaved (scale, offset) pairs loaded by vload8 from dequantScaleOffset at
// block (k*4)/blockDim. Optional RELU / RELU6 run before the write to (pos.x, pos.y).
"__kernel void gemm_conv(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const FLOAT *weight,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
"#endif\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" FLOAT4 out=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#else\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
"#endif\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(k,pos.y));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
" \n"
"#else\n"
" FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n"
" out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n"
" out=mad((FLOAT4)in.z,(FLOAT4)weights.s89ab,out);\n"
" out=mad((FLOAT4)in.w,(FLOAT4)weights.scdef,out);\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out=fmax(out,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos.y),out);\n"
"}\n"
// gemm_conv_b2: same computation as gemm_conv, but each work-item handles two
// rows, pos_y = 2*pos.y and pos_y + 1, reusing one weight tile for both. The
// second row is written only when pos_y + 1 < batch. (pos_x is computed but unused.)
"__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const FLOAT *weight,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
"#endif\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int pos_x=pos.x << 2;\n"
" int pos_y=pos.y << 1;\n"
" FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
" FLOAT4 out0=bias0,out1=bias0;\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#else\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
"#endif\n"
" FLOAT4 in0=RI_F(input,SAMPLER,(int2)(k,pos_y));\n"
" FLOAT4 in1=RI_F(input,SAMPLER,(int2)(k,pos_y+1));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
"#else\n"
" FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n"
" out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n"
" out0=mad((FLOAT4)in0.z,(FLOAT4)weights.s89ab,out0);\n"
" out0=mad((FLOAT4)in0.w,(FLOAT4)weights.scdef,out0);\n"
" \n"
" out1=mad((FLOAT4)in1.x,(FLOAT4)weights.s0123,out1);\n"
" out1=mad((FLOAT4)in1.y,(FLOAT4)weights.s4567,out1);\n"
" out1=mad((FLOAT4)in1.z,(FLOAT4)weights.s89ab,out1);\n"
" out1=mad((FLOAT4)in1.w,(FLOAT4)weights.scdef,out1);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos_y),out0);\n"
" if(pos_y+1<batch)\n"
" WI_F(output,(int2)(pos.x,pos_y+1),out1);\n"
"}\n"
;
// OpenCL source for the image-based depthwise transposed convolution kernel
// "depthwise_deconv2d".
//
// Work-item layout (from the get_global_id calls in the kernel):
//   dim0 = output channel block (one float4 texel per block)
//   dim1 = output column
//   dim2 = batch * output_height + output row
// The int2 arguments are (height, width): .x indexes rows, .y indexes
// columns (see the output_shape.x division and the output_shape.y mad24).
// Each work item starts from the bias texel (or zero when NO_BIAS is
// defined) and accumulates input * weight over every kernel tap that maps
// onto its output pixel. It then applies the optional RELU / RELU6
// activation ONCE and writes a single texel at
// (out_channel_block * output_shape.y + out_width, out_batch_height_idx).
// kernel_size and out_channel_blocks are accepted but not read.
//
// Fix: the activation and write_imagef used to sit inside the outer
// kernel-row loop. That caused two problems:
//   - RELU/RELU6 clamped partial sums between rows. For example, a running
//     -5 became 0 before a later +3 tap, giving 3 instead of max(-2, 0) = 0.
//   - No texel was written when the outer loop ran zero iterations
//     (deal_kernel_height < 0).
// Both loops now close before the epilogue.
// NOTE(review): this file looks generated from the .cl sources; the same
// change presumably belongs in depthwise_deconv2d.cl — confirm.
const char* depthwise_deconv2d = 
"#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=read_imagef(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n"
"#define CALCULATE_OUTPUT(i) "" out##i = mad(in##i.x, weights0, out##i); "" out##i = mad(in##i.y, weights1, out##i); "" out##i = mad(in##i.z, weights2, out##i); "" out##i=mad(in##i.w,weights3,out##i);\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void depthwise_deconv2d(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n"
" __read_only image2d_t weights,\n"
" #ifndef NO_BIAS\n"
" __read_only image2d_t bias,\n"
" #endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int2 output_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 align_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 kernel_shape,\n"
" __private const int kernel_size,__private const int out_channel_blocks) {\n"
" const int out_channel_blocks_idx=get_global_id(0);\n"
" const int out_width_idx=get_global_id(1);\n"
" const int out_batch_height_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(out_channel_blocks_idx,out_width_idx,out_batch_height_idx);\n"
" #ifndef NO_BIAS\n"
" float4 out0=read_imagef(bias,SAMPLER,(int2)(out_channel_blocks_idx,0));\n"
" #else\n"
" float4 out0=(float4)(0.0);\n"
" #endif\n"
" const int out_batch_idx=out_batch_height_idx/output_shape.x;\n"
" const int out_height_idx=out_batch_height_idx % output_shape.x;\n"
" int kernel_start_x=(out_width_idx+align_shape.y)/stride_shape.y;\n"
" int kernel_start_y=(out_height_idx+align_shape.x)/stride_shape.x;\n"
" int deal_kernel_width=kernel_shape.y-mad24(kernel_start_x,stride_shape.y,padding_shape.y)+out_width_idx-1;\n"
" int deal_kernel_height=kernel_shape.x-mad24(kernel_start_y,stride_shape.x,padding_shape.x)+out_height_idx-1;\n"
" int kernel_image_x;\n"
" float4 in0;\n"
" float4 weight;\n"
" int in_width0;\n"
" int in_idx,in_idy;\n"
" for (int k_y=deal_kernel_height,idx_h=kernel_start_y; k_y >= 0; k_y -= stride_shape.x,idx_h++) {\n"
" in_idy=mad24(out_batch_idx,input_shape.x,idx_h);\n"
" int in_hb_value=select(in_idy,-1,idx_h<0 || idx_h >= input_shape.x);\n"
" for (int k_x=deal_kernel_width,in_width_idx=kernel_start_x; k_x >= 0; k_x -= stride_shape.y,in_width_idx++) {\n"
" in_width0=in_width_idx;\n"
" in_idx=mul24(out_channel_blocks_idx,input_shape.y);\n"
" READ_INPUT_IMAGE(0,0);\n"
" kernel_image_x=mad24(k_y,kernel_shape.y,k_x);\n"
" weight=read_imagef(weights,SAMPLER,(int2)(kernel_image_x,out_channel_blocks_idx));\n"
" out0=mad(in0,weight,out0);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(float4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(float4)0,(float4)6);\n"
"#endif\n"
" const int output_image_x=mad24(out_channel_blocks_idx,output_shape.y,out_width_idx);\n"
" write_imagef(output,(int2)(output_image_x,out_batch_height_idx),out0);\n"
"}\n"
;
// OpenCL source for the image-based "range" kernel.
//
// Work-item layout:
//   dim0 = column
//   dim1 = row
//   dim2 = batch * channelBlock + channel block
// Each work item fills one 4-channel texel with start + index * step.
// start and step are the .x component of texel (0,0) of input0 and input2.
// index is the element's flat NCHW offset
//   ((batch*channel + c)*height + row)*width + col, for c = 4*block .. 4*block+3
// (consecutive channels differ by height*width). The texel is written at
// (block*width + col, batch*height + row).
// INPUT_TYPE_I, OUTPUT_TYPE_I4, CONVERT_OUTPUT_I4, RI_DATA and WI_DATA are
// not defined in this string; presumably they are injected as build
// options — confirm at the kernel build site.
const char* range = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void range(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input0,\n"
" __read_only image2d_t input2,\n"
" __write_only image2d_t output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
" __private const int channelBlock\n"
" ) {\n"
" const int width_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n"
" \n"
" const int batch_idx=batch_channel_idx/channelBlock;\n"
" const int channel_idx=batch_channel_idx % channelBlock;\n"
" \n"
" const int bh=batch_idx*height+height_idx;\n"
" const int cw=channel_idx*width+width_idx;\n"
" const int channel4=channel_idx << 2;\n"
" int index=(((batch_idx*channel)+channel4)*height+height_idx)*width+width_idx;\n"
" int size=height*width;\n"
" int4 index4=(int4)(index,index+size,index+size*2,index+size*3);\n"
" INPUT_TYPE_I start=RI_DATA(input0,SAMPLER,(int2)(0,0)).x;\n"
" INPUT_TYPE_I step=RI_DATA(input2,SAMPLER,(int2)(0,0)).x;\n"
" OUTPUT_TYPE_I4 value=(OUTPUT_TYPE_I4)start+CONVERT_OUTPUT_I4(index4)*(OUTPUT_TYPE_I4)step;\n"
" WI_DATA(output,(int2)(cw,bh),value);\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL source for the buffer-based per-channel scale (+ optional bias) kernel.
//
// Work-item layout (per the kernel's own comments):
//   dim0 = inside (width*height)
//   dim1 = channelBlock * batch
// Each work item loads one 4-element vector from input at
// ((batch_idx + channel_block*batch)*inside + x)*4. It computes
// in * scale[block], plus bias[block] when BIAS is defined, in
// COMPUTE_FLOAT4, and stores the result to output at the same offset.
// FLOAT, COMPUTE_FLOAT4 and the CONVERT_* macros are not defined in this
// string; presumably they come from build options — confirm.
const char* scale_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__kernel void scale_buf(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT* input,\n"
" __global const FLOAT* scale,\n"
"#ifdef BIAS\n"
" __global const FLOAT* bias,\n"
"#endif\n"
" __global FLOAT* output,\n"
" __private const int channelBlock,\n"
" __private const int batch,\n"
" __private const int inside) {\n"
" const int x=get_global_id(0); // inside(width*height)\n"
" const int y=get_global_id(1); // channelBlock*batch\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int out_c_idx=y % channelBlock;\n"
" const int out_b_idx=y/channelBlock;\n"
" const int offset=((out_b_idx+out_c_idx*batch)*inside+x)*4;\n"
" COMPUTE_FLOAT4 in_value=CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset));\n"
" COMPUTE_FLOAT4 scale_value=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,scale));\n"
" #ifdef BIAS\n"
" COMPUTE_FLOAT4 bias_value=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out_value=in_value*scale_value+bias_value;\n"
" #else\n"
" COMPUTE_FLOAT4 out_value=in_value*scale_value;\n"
" #endif\n"
" vstore4(CONVERT_FLOAT4(out_value),0,output+offset);\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL source for the buffer-based matrix multiply:
//   output_c (M x N, row-major) = A * B, plus input_c (length-N row,
//   broadcast over all rows) when BIAS is defined.
//
// Each work item computes one 4x4 output tile:
//   dim0 = column tile (idn = 4*dim0)
//   dim1 = row tile (idm = 4*dim1)
//
// Layout / edge macros:
//   TRANSPOSE_A: A is stored K x M (otherwise M x K).
//   TRANSPOSE_B: B is stored N x K (otherwise K x N).
//   K_LEAVE: the vector loop stops one 4-depth block early, and a scalar
//     loop handles the last 1..4 depths. Without K_LEAVE the vector loop
//     runs ceil(K/4) times and always loads 4 depths, so K is presumably
//     a multiple of 4 in that configuration.
//   M_LEAVE / N_LEAVE: the partial last tile is stored with per-element
//     loops clipped to M / N.
//
// NOTE(review): operand loads for rows idm..idm+3 / cols idn..idn+3 and the
// bias vload4 are unconditional even under M_LEAVE / N_LEAVE, so edge tiles
// read past M / N. Presumably the host pads these buffers — confirm.
const char* matmul_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) ""if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { ""return; ""}\n"
"__kernel void matmul_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a,\n"
" __global const FLOAT* input_b,\n"
" #ifdef BIAS\n"
" __global const FLOAT* input_c,\n"
" #endif\n"
" __global FLOAT* output_c,\n"
" __private const int M,\n"
" __private const int N,\n"
" __private const int K) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); // N M\n"
" DEAL_NON_UNIFORM_DIM2(pos.x,pos.y);\n"
" const int idn=pos.x << 2;\n"
" const int idm=pos.y << 2;\n"
" \n"
" COMPUTE_FLOAT4 out[4];\n"
" #ifdef BIAS\n"
" COMPUTE_FLOAT4 bias=CONVERT_COMPUTE_FLOAT4(vload4(0,input_c+idn));\n"
" #pragma unroll\n"
" for(int i=0; i<4; ++i){\n"
" out[i]=bias;\n"
" }\n"
" #else\n"
" #pragma unroll\n"
" for(int i=0; i<4; ++i){\n"
" out[i]=(COMPUTE_FLOAT4)0;\n"
" }\n"
" #endif\n"
// Depth is consumed 4 at a time; K_LEAVE leaves the last block to the scalar tail.
" const int K4=(K+3)/4;\n"
" #ifdef K_LEAVE\n"
" const int loop_end=max(K4-1,0);\n"
" const int remain=K-loop_end*4;\n"
" #else\n"
" const int loop_end=K4;\n"
" #endif\n"
" \n"
" #ifdef TRANSPOSE_A\n"
" __global const FLOAT* input_a_offset=input_a+idm; // K x M\n"
" #else\n"
" __global const FLOAT* input_a_offset=input_a+idm*K; // M x K\n"
" #endif\n"
" \n"
" #ifdef TRANSPOSE_B\n"
" __global const FLOAT* input_b_offset=input_b+idn*K; // N x K\n"
" #else\n"
" __global const FLOAT* input_b_offset=input_b+idn; // K x N\n"
" #endif\n"
" \n"
// Vector loop: per iteration, A[m] holds 4 depths of row idm+m and B[k]
// holds 4 columns at depth kindex+k.
" for (int k=0; k<loop_end; ++k) {\n"
" int kindex=k << 2;\n"
" COMPUTE_FLOAT4 A[4]; // m4 x k4\n"
" COMPUTE_FLOAT4 B[4]; // k4 x n4\n"
" #ifdef TRANSPOSE_A\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex*M));\n"
" COMPUTE_FLOAT4 tmp1=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+(kindex+1)*M));\n"
" COMPUTE_FLOAT4 tmp2=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+(kindex+2)*M));\n"
" COMPUTE_FLOAT4 tmp3=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+(kindex+3)*M));\n"
" \n"
" A[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" A[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" A[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" A[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
" }\n"
" #else\n"
" A[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex));\n"
" A[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+K));\n"
" A[2]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+2*K));\n"
" A[3]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+3*K));\n"
" #endif\n"
" \n"
" #ifdef TRANSPOSE_B\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex));\n"
" COMPUTE_FLOAT4 tmp1=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+K));\n"
" COMPUTE_FLOAT4 tmp2=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+2*K));\n"
" COMPUTE_FLOAT4 tmp3=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+3*K));\n"
" \n"
" B[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" B[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" B[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" B[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
" }\n"
" #else\n"
" B[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex*N));\n"
" B[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+1)*N));\n"
" B[2]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+2)*N));\n"
" B[3]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+3)*N));\n"
" #endif\n"
" \n"
" #pragma unroll\n"
" for (int vec_m=0; vec_m<4; ++vec_m){\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].x,B[0],out[vec_m]);\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].y,B[1],out[vec_m]);\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].z,B[2],out[vec_m]);\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].w,B[3],out[vec_m]);\n"
" }\n"
" }\n"
// Scalar tail: one depth per iteration over the remaining k (K_LEAVE only).
" #ifdef K_LEAVE\n"
" for (int k=loop_end << 2; k<K; ++k){\n"
" COMPUTE_FLOAT4 A; // m4\n"
" COMPUTE_FLOAT4 B; // n4\n"
" #ifdef TRANSPOSE_A\n"
" A=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+k*M));\n"
" #else\n"
" A.x=(COMPUTE_FLOAT)input_a_offset[k];\n"
" A.y=(COMPUTE_FLOAT)input_a_offset[k+K];\n"
" A.z=(COMPUTE_FLOAT)input_a_offset[k+2*K];\n"
" A.w=(COMPUTE_FLOAT)input_a_offset[k+3*K];\n"
" #endif\n"
" \n"
" #ifdef TRANSPOSE_B\n"
" B.x=(COMPUTE_FLOAT)input_b_offset[k];\n"
" B.y=(COMPUTE_FLOAT)input_b_offset[k+K];\n"
" B.z=(COMPUTE_FLOAT)input_b_offset[k+2*K];\n"
" B.w=(COMPUTE_FLOAT)input_b_offset[k+3*K];\n"
" #else\n"
" B=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+k*N));\n"
" #endif\n"
" out[0]=mad((COMPUTE_FLOAT4)A.x,B,out[0]);\n"
" out[1]=mad((COMPUTE_FLOAT4)A.y,B,out[1]);\n"
" out[2]=mad((COMPUTE_FLOAT4)A.z,B,out[2]);\n"
" out[3]=mad((COMPUTE_FLOAT4)A.w,B,out[3]);\n"
" }\n"
" #endif\n"
" \n"
" \n"
// Store the tile at (row idm, col idn) of the row-major M x N output.
// M_LEAVE / N_LEAVE switch to per-element stores clipped to M / N.
" const int out_offset=idm*N+idn;\n"
" #ifdef M_LEAVE\n"
" if(idm+3 >= M){\n"
" #ifdef N_LEAVE\n"
" if(idn+3 >= N){\n"
" for (int vec_m=0; vec_m<M-idm; ++vec_m){\n"
" COMPUTE_FLOAT *out_ptr=(COMPUTE_FLOAT*)&out[vec_m];\n"
" for(int vec_n=0; vec_n<N-idn; ++vec_n){\n"
" output_c[out_offset+vec_m*N+vec_n]=out_ptr[vec_n];\n"
" }\n"
" }\n"
" } else {\n"
" #endif\n"
" for (int vec_m=0; vec_m<M-idm; ++vec_m){\n"
" vstore4(CONVERT_FLOAT4(out[vec_m]),0,output_c+out_offset+vec_m*N);\n"
" }\n"
" \n"
" #ifdef N_LEAVE\n"
" }\n"
" #endif\n"
" } else{\n"
" #endif\n"
" #ifdef N_LEAVE\n"
" if(idn+3 >= N){\n"
" #pragma unroll\n"
" for (int vec_m=0; vec_m<4; ++vec_m){\n"
" COMPUTE_FLOAT *out_ptr=(COMPUTE_FLOAT*)&out[vec_m];\n"
" for(int vec_n=0; vec_n<N-idn; ++vec_n){\n"
" output_c[out_offset+vec_m*N+vec_n]=out_ptr[vec_n];\n"
" }\n"
" }\n"
" } else {\n"
" #endif\n"
" #pragma unroll\n"
" for (int vec_m=0; vec_m<4; ++vec_m){\n"
" vstore4(CONVERT_FLOAT4(out[vec_m]),0,output_c+out_offset+vec_m*N);\n"
" }\n"
" #ifdef N_LEAVE\n"
" }\n"
" #endif\n"
" #ifdef M_LEAVE\n"
" }\n"
" #endif\n"
"}\n"
;
#endif
// OpenCL source for image-based 2D pooling (kernels "pooling" and
// "global_pooling").
//
// pooling: one work item per output texel.
//   dim0 = output channel block
//   dim1 = output column (the output width is taken from global_size_dim1)
//   dim2 = batch * output_height + output row
//   The int2 arguments are (height, width): .x indexes rows, .y columns.
//   POOL_AVG selects average pooling; otherwise max pooling is used.
//   Average-pooling divisor: counts only in-bounds positions unless
//   COUNT_INCLUDE_PADDING is defined, in which case it counts up to
//   input + pad.
//   With RETURN_REDICE, max pooling also writes the flat index
//   (row * input_width + col) of the first maximum, converted to FLOAT4,
//   to rediceOutput.
//
// global_pooling (compiled only when LOCAL_SIZE is defined):
//   dim0 is a work-group of LOCAL_SIZE items that strides over the whole
//   input plane of (dim1 = channel block, dim2 = batch). The partial
//   results are combined by a halving tree reduction in local memory.
//   That reduction only covers every slot when LOCAL_SIZE is a power of two.
//
// Fix: global_pooling evaluated `in > output_result` AFTER output_result
// had been raised to fmax(output_result, in). The comparison was therefore
// always false, so rediceOutput was always 0. The argmax update now runs
// before fmax, matching the order used in the pooling kernel.
//
// FLOAT, FLOAT4, RI_F, WI_F and CONVERT_FLOAT4 are not defined in this
// string; presumably they are supplied as build options — confirm.
const char* pooling = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n"
" __private const int2 input_shape,__private const int output_height,__private const int2 pad_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 kernel_shape,\n"
" __write_only image2d_t output,\n"
" __write_only image2d_t rediceOutput) {\n"
" const int output_channel_idx=get_global_id(0);\n"
" const int output_width_idx=get_global_id(1);\n"
" const int output_batch_height_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(output_channel_idx,output_width_idx,output_batch_height_idx);\n"
" const int output_width=global_size_dim1;\n"
" const int output_batch_idx=output_batch_height_idx/output_height;\n"
" const int output_height_idx=output_batch_height_idx-mul24(output_batch_idx,output_height);\n"
" const int input_start=mul24(output_batch_idx,input_shape.x);\n"
" const int input_height_start=mad24(output_height_idx,stride_shape.x,-pad_shape.x);\n"
" const int input_width_start=mad24(output_width_idx,stride_shape.y,-pad_shape.y);\n"
" const int input_channel_start=mul24(output_channel_idx,input_shape.y);\n"
"#ifdef POOL_AVG\n"
" FLOAT4 output_result=0;\n"
" for (int height=0; height<kernel_shape.x; height++) {\n"
" int input_height_idx=input_height_start+height;\n"
" input_height_idx =\n"
" select(input_start+input_height_idx,-1,(input_height_idx<0 || input_height_idx >= input_shape.x));\n"
" for (int width=0; width<kernel_shape.y; width++) {\n"
" int input_width_idx=input_width_start+width;\n"
" input_width_idx =\n"
" select(input_channel_start+input_width_idx,-1,(input_width_idx<0 || input_width_idx >= input_shape.y));\n"
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(input_width_idx,input_height_idx));\n"
" output_result=output_result+input_data;\n"
" }\n"
" }\n"
" const int kernel_height_start=max(0,input_height_start);\n"
" const int kernel_width_start=max(0,input_width_start);\n"
" const int kernel_height_end=min(input_height_start+kernel_shape.x,input_shape.x);\n"
" const int kernel_width_end=min(input_width_start+kernel_shape.y,input_shape.y);\n"
" #ifdef COUNT_INCLUDE_PADDING\n"
" const int block_size=(min(input_height_start+kernel_shape.x,input_shape.x+pad_shape.x)-input_height_start)*(min(input_width_start+kernel_shape.y,input_shape.y+pad_shape.y)-input_width_start);\n"
" #else\n"
" const int block_size=mul24((kernel_height_end-kernel_height_start),(kernel_width_end-kernel_width_start));\n"
" #endif\n"
" const FLOAT block_float_req=(FLOAT)1.0f/(FLOAT)block_size;\n"
" output_result=output_result*block_float_req;\n"
"#else\n"
" FLOAT4 output_result=(FLOAT4)(-FLT_MAX);\n"
" #if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" for (int height=0; height<kernel_shape.x; height++) {\n"
" int input_height_idx=input_height_start+height;\n"
" input_height_idx =\n"
" select(input_start+input_height_idx,-1,(input_height_idx<0 || input_height_idx >= input_shape.x));\n"
" if (input_height_idx != -1) {\n"
" for (int width=0; width<kernel_shape.y; width++) {\n"
" int input_width_idx=input_width_start+width;\n"
" input_width_idx=select(input_channel_start+input_width_idx,-1,\n"
" (input_width_idx<0 || input_width_idx >= input_shape.y));\n"
" if (input_width_idx != -1) {\n"
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(input_width_idx,input_height_idx));\n"
" #if RETURN_REDICE\n"
" redice=input_data>output_result ? (int4)((input_height_start+height)*input_shape.y+input_width_start+width) : redice;\n"
" #endif\n"
" output_result=fmax(output_result,input_data);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"#endif\n"
" const int output_channel_width_idx=mad24(output_channel_idx,output_width,output_width_idx);\n"
" WI_F(output,(int2)(output_channel_width_idx,output_batch_height_idx),output_result);\n"
" #if RETURN_REDICE\n"
" WI_F(rediceOutput,(int2)(output_channel_width_idx,output_batch_height_idx),CONVERT_FLOAT4(redice));\n"
" #endif\n"
"}\n"
"#ifdef LOCAL_SIZE\n"
"__kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n"
" __private const int2 input_shape,__private const int output_height,__private const int2 pad_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 kernel_shape,\n"
" __write_only image2d_t output,\n"
" __write_only image2d_t rediceOutput) {\n"
" const int local_id=get_local_id(0);\n"
" const int output_channel_idx=get_global_id(1);\n"
" const int output_batch_idx=get_global_id(2);\n"
"#ifdef POOL_AVG\n"
" FLOAT4 output_result=0;\n"
"#else\n"
" FLOAT4 output_result=(FLOAT4)(-FLT_MAX);\n"
"#if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" int4 local rediceId[LOCAL_SIZE];\n"
"#endif\n"
"#endif\n"
" FLOAT4 local sum[LOCAL_SIZE];\n"
" int wc=output_channel_idx*input_shape.y;\n"
" int bh=output_batch_idx*input_shape.x;\n"
" for(int i=local_id; i<input_shape.x*input_shape.y; i+=LOCAL_SIZE){\n"
" int w=i % input_shape.y;;\n"
" int h=i/input_shape.y;\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(wc+w,bh+h));\n"
"#ifdef POOL_AVG\n"
" output_result += in;\n"
"#else\n"
"#if RETURN_REDICE\n"
" redice=in>output_result ? (int4)(i) : redice;\n"
"#endif\n"
" output_result=fmax(output_result,in);\n"
"#endif\n"
" }\n"
" \n"
" sum[local_id]=output_result;\n"
"#if RETURN_REDICE\n"
" rediceId[local_id]=redice;\n"
"#endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (local_id<i)\n"
"#ifdef POOL_AVG\n"
" sum[local_id]=sum[local_id]+sum[local_id+i];\n"
"#else\n"
" {\n"
" sum[local_id]=fmax(sum[local_id],sum[local_id+i]);\n"
"#if RETURN_REDICE\n"
" rediceId[local_id]=sum[local_id]>sum[local_id+i] ? rediceId[local_id] : rediceId[local_id+i];\n"
"#endif\n"
" }\n"
"#endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" output_result=sum[0];\n"
"#ifdef POOL_AVG\n"
" output_result /= (input_shape.x*input_shape.y);\n"
"#endif\n"
" WI_F(output,(int2)(output_channel_idx,output_batch_idx),output_result);\n"
" #if RETURN_REDICE\n"
" redice=rediceId[0];\n"
" WI_F(rediceOutput,(int2)(output_channel_idx,output_batch_idx),CONVERT_FLOAT4(redice));\n"
" #endif\n"
"}\n"
"#endif\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* conv_2d_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#ifdef CONV_LOCAL_SIZE\n"
"__kernel\n"
"void conv_2d_1x1_local(__private const int out_w_blocks,\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" __global FLOAT *output,\n"
" __private const int in_c_block,\n"
" __private const int batch,\n"
" __private const int out_h,\n"
" __private const int out_w,\n"
" __private const int out_c_block,\n"
" __private const int out_c_pack) {\n"
" const int lid=get_local_id(0);\n"
" const int out_c_w_idx=get_global_id(1); //c/4 w\n"
" const int out_b_h_idx=get_global_id(2); //b h\n"
" \n"
" COMPUTE_FLOAT4 local sum[CONV_LOCAL_SIZE];\n"
" \n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h; // equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_h; // equal to in_h_idx\n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias_ptr));\n"
" COMPUTE_FLOAT4 out0=(COMPUTE_FLOAT4)0;\n"
" int offset=out_c_idx*4;\n"
" int inp_offset=(((out_b_idx+in_c_block*batch)*out_h+out_h_idx)* out_w+out_w_idx) << 2;\n"
" \n"
" const int inp_add=batch*out_h*out_w*4;\n"
" for (ushort in_channel_block_idx=lid; in_channel_block_idx<in_c_block; in_channel_block_idx+=CONV_LOCAL_SIZE) {\n"
" \n"
" int offset=mad24(in_channel_block_idx*4,out_c_pack,out_c_idx*4);\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+in_channel_block_idx*inp_add));\n"
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights3=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights1,out0);\n"
" out0=mad(in0.z,weights2,out0);\n"
" out0=mad(in0.w,weights3,out0);\n"
" }\n"
" \n"
" sum[lid]=out0;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=CONV_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out0=sum[0]+bias0;\n"
" if(lid == 0){\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_h+out_h_idx)* out_w+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"}\n"
"#endif\n"
"__kernel\n"
"void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" __global FLOAT *output,\n"
" __private const int in_c_block,\n"
" __private const int out_h,\n"
" __private const int out_w,\n"
" __private const int out_b,\n"
" __private const int out_c_block,\n"
" __private const int out_c_pack) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h; // equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_h; // equal to in_h_idx\n"
" const int out_w4_idx=mul24(out_w_idx,4);\n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias_ptr));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=out0;\n"
" COMPUTE_FLOAT4 out3=out0;\n"
" const int intput_width_idx0=out_w4_idx;\n"
" int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0) << 2;\n"
" int offset=out_c_idx*4;\n"
" const int inp_add=out_b*out_h*out_w*4;\n"
" for (ushort in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" \n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4(vload4(2,input+inp_offset));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4(vload4(3,input+inp_offset));\n"
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights3=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights1,out0);\n"
" out0=mad(in0.z,weights2,out0);\n"
" out0=mad(in0.w,weights3,out0);\n"
" \n"
" out1=mad(in1.x,weights0,out1);\n"
" out1=mad(in1.y,weights1,out1);\n"
" out1=mad(in1.z,weights2,out1);\n"
" out1=mad(in1.w,weights3,out1);\n"
" \n"
" out2=mad(in2.x,weights0,out2);\n"
" out2=mad(in2.y,weights1,out2);\n"
" out2=mad(in2.z,weights2,out2);\n"
" out2=mad(in2.w,weights3,out2);\n"
" \n"
" out3=mad(in3.x,weights0,out3);\n"
" out3=mad(in3.y,weights1,out3);\n"
" out3=mad(in3.z,weights2,out3);\n"
" out3=mad(in3.w,weights3,out3);\n"
" \n"
" offset += 4*out_c_pack;\n"
" inp_offset += inp_add;\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*out_b)*out_h+out_h_idx)* out_w+out_w4_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_w-out_w4_idx;\n"
" if (remain >= 4) {\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
" } else if (remain == 3) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2,output+out_offset);\n"
" } else if (remain == 2) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" __global FLOAT *output,\n"
" __private const int in_c_block,\n"
" __private const int out_h,\n"
" __private const int out_w,\n"
" __private const int out_b,\n"
" __private const int out_c_block,\n"
" __private const int out_c_pack) {\n"
" const int out_c_w_idx=get_global_id(0); //c/8 w/4\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n"
" const int out_w4_idx=mul24(out_w_idx,4);\n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1,bias_ptr));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=out0;\n"
" COMPUTE_FLOAT4 out3=out0;\n"
" \n"
" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1,bias_ptr));\n"
" COMPUTE_FLOAT4 out5=out4;\n"
" COMPUTE_FLOAT4 out6=out4;\n"
" COMPUTE_FLOAT4 out7=out4;\n"
" const int intput_width_idx0=out_w4_idx;\n"
" int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)<<2;\n"
" int offset=out_c_idx*8;\n"
" const int inp_add=out_b*out_h*out_w*4;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" \n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4(vload4(2,input+inp_offset));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4(vload4(3,input+inp_offset));\n"
" \n"
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights3=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights4=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights5=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights6=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights7=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights2,out0);\n"
" out0=mad(in0.z,weights4,out0);\n"
" out0=mad(in0.w,weights6,out0);\n"
" \n"
" out1=mad(in1.x,weights0,out1);\n"
" out1=mad(in1.y,weights2,out1);\n"
" out1=mad(in1.z,weights4,out1);\n"
" out1=mad(in1.w,weights6,out1);\n"
" \n"
" out2=mad(in2.x,weights0,out2);\n"
" out2=mad(in2.y,weights2,out2);\n"
" out2=mad(in2.z,weights4,out2);\n"
" out2=mad(in2.w,weights6,out2);\n"
" \n"
" out3=mad(in3.x,weights0,out3);\n"
" out3=mad(in3.y,weights2,out3);\n"
" out3=mad(in3.z,weights4,out3);\n"
" out3=mad(in3.w,weights6,out3);\n"
" \n"
" out4=mad(in0.x,weights1,out4);\n"
" out4=mad(in0.y,weights3,out4);\n"
" out4=mad(in0.z,weights5,out4);\n"
" out4=mad(in0.w,weights7,out4);\n"
" \n"
" out5=mad(in1.x,weights1,out5);\n"
" out5=mad(in1.y,weights3,out5);\n"
" out5=mad(in1.z,weights5,out5);\n"
" out5=mad(in1.w,weights7,out5);\n"
" \n"
" out6=mad(in2.x,weights1,out6);\n"
" out6=mad(in2.y,weights3,out6);\n"
" out6=mad(in2.z,weights5,out6);\n"
" out6=mad(in2.w,weights7,out6);\n"
" \n"
" out7=mad(in3.x,weights1,out7);\n"
" out7=mad(in3.y,weights3,out7);\n"
" out7=mad(in3.z,weights5,out7);\n"
" out7=mad(in3.w,weights7,out7);\n"
" \n"
" offset += 4*out_c_pack;\n"
" inp_offset += inp_add;\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
" \n"
" out4=fmax(out4,(COMPUTE_FLOAT4)0);\n"
" out5=fmax(out5,(COMPUTE_FLOAT4)0);\n"
" out6=fmax(out6,(COMPUTE_FLOAT4)0);\n"
" out7=fmax(out7,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" \n"
" out4=clamp(out4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out5=clamp(out5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*2*out_b)*out_h+out_h_idx)* out_w+out_w4_idx)*4;\n"
" __global FLOAT*_tempoutput=output+out_offset;\n"
" __global FLOAT*_tempoutput1=_tempoutput+4*out_h*out_w*out_b;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_w-out_w4_idx;\n"
" if (remain >= 4) {\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,_tempoutput);\n"
" } else if (remain == 3) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,_tempoutput);\n"
" vstore4(CONVERT_FLOAT4(out2),2,_tempoutput);\n"
" } else if (remain == 2) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,_tempoutput);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(out0),0,_tempoutput);\n"
" }\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx*2+1 >= out_c_block) {\n"
" return;\n"
" }\n"
"#endif\n"
" if (remain >= 4) {\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,_tempoutput1);\n"
" } else if (remain == 3) {\n"
" vstore8(CONVERT_FLOAT8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5))),0,_tempoutput1);\n"
" vstore4(CONVERT_FLOAT4(out6),2,_tempoutput1);\n"
" } else if (remain == 2) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,_tempoutput1);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(out4),0,_tempoutput1);\n"
" }\n"
"#else\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,_tempoutput);\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx*2+1 >= out_c_block) {\n"
" return;\n"
" }\n"
"#endif\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,_tempoutput1);\n"
"#endif\n"
"}\n"
// conv_2d_1x1_c8h1w2: 1x1 convolution over NC4HW4 buffers.
// Each work-item produces 8 output channels, i.e. two consecutive 4-channel
// blocks (2*out_c_idx and 2*out_c_idx+1), for 2 adjacent output columns of a
// single (batch, row) pair.
//   global dim0 = out_c_idx * out_w_blocks + w_block   (out_w2_idx = 2 * w_block)
//   global dim1 = out_b_idx * out_h + out_h_idx
// Input/output element (cb, b, h, w, lane) lives at
//   (((cb*out_b + b)*out_h + h)*out_w + w)*4 + lane.
// The input is indexed with the output's out_b/out_h/out_w, so the kernel
// assumes the input and output spatial sizes are equal.
// kernel_ptr: every input channel owns a row of out_c_pack weights. This item
// reads 8 consecutive entries starting at out_c_idx*8 from rows
// 4*ic_block .. 4*ic_block+3.
// Macros honoured here:
//   RELU / RELU6  - fused activation.
//   BLOCK_LEAVE   - partial column block at the right edge.
//   CHANNEL_LEAVE - drop the second 4-channel block when it is past out_c_block.
// Fix: under BLOCK_LEAVE the second column may equal out_w. Loading it anyway
// reads 4 elements past the end of the tensor data for the final (batch, row)
// of the last input channel block. That load only fed out1/out5, which are
// never stored in that case, so it is now skipped and replaced by zero.
"__kernel\n"
"void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" __global FLOAT *output,\n"
" __private const int in_c_block,\n"
" __private const int out_h,\n"
" __private const int out_w,\n"
" __private const int out_b,\n"
" __private const int out_c_block,\n"
" __private const int out_c_pack) {\n"
" const int out_c_w_idx=get_global_id(0); //c/8 w/2\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n"
" \n"
" const int out_w2_idx=mul24(out_w_idx,2);\n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1,bias_ptr));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" \n"
" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1,bias_ptr));\n"
" COMPUTE_FLOAT4 out5=out4;\n"
" const int intput_width_idx0=out_w2_idx;\n"
" int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)<<2;\n"
" int offset=out_c_idx*8;\n"
" const int inp_add=out_b*out_h*out_w*4;\n"
// Under BLOCK_LEAVE, the second column exists only when it is still inside the row.
"#ifdef BLOCK_LEAVE\n"
" const int has_w1=(out_w2_idx+1)<out_w;\n"
"#endif\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" \n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
// Skip the out-of-row column instead of reading past the tensor; the
// scalar-condition ?: evaluates only the selected operand.
"#ifdef BLOCK_LEAVE\n"
" COMPUTE_FLOAT4 in1=has_w1 ? CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset)) : (COMPUTE_FLOAT4)0;\n"
"#else\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n"
"#endif\n"
// Even weightsN are the first 4-channel block, odd weightsN the second;
// weights(2K) and weights(2K+1) pair with input lane K.
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights3=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights4=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights5=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights6=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights7=CONVERT_COMPUTE_FLOAT4(vload4(1,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights2,out0);\n"
" out0=mad(in0.z,weights4,out0);\n"
" out0=mad(in0.w,weights6,out0);\n"
" \n"
" out1=mad(in1.x,weights0,out1);\n"
" out1=mad(in1.y,weights2,out1);\n"
" out1=mad(in1.z,weights4,out1);\n"
" out1=mad(in1.w,weights6,out1);\n"
" \n"
" out4=mad(in0.x,weights1,out4);\n"
" out4=mad(in0.y,weights3,out4);\n"
" out4=mad(in0.z,weights5,out4);\n"
" out4=mad(in0.w,weights7,out4);\n"
" \n"
" out5=mad(in1.x,weights1,out5);\n"
" out5=mad(in1.y,weights3,out5);\n"
" out5=mad(in1.z,weights5,out5);\n"
" out5=mad(in1.w,weights7,out5);\n"
" \n"
" offset += 4*out_c_pack;\n"
" inp_offset += inp_add;\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out4=fmax(out4,(COMPUTE_FLOAT4)0);\n"
" out5=fmax(out5,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out4=clamp(out4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out5=clamp(out5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
// The second output channel block starts one full block (4*out_h*out_w*out_b elements) later.
" const int out_offset=(((out_b_idx+out_c_idx*2*out_b)*out_h+out_h_idx)* out_w+out_w2_idx)*4;\n"
" __global FLOAT*_tempoutput=output+out_offset;\n"
" __global FLOAT*_tempoutput1=_tempoutput+4*out_h*out_w*out_b;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_w-out_w2_idx;\n"
" if (remain >= 2) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,_tempoutput);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(out0),0,_tempoutput);\n"
" }\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx*2+1 >= out_c_block) {\n"
" return;\n"
" }\n"
"#endif\n"
" if (remain >= 2) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,_tempoutput1);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(out4),0,_tempoutput1);\n"
" }\n"
"#else\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,_tempoutput);\n"
"#ifdef CHANNEL_LEAVE\n"
" if(out_c_idx*2+1 >= out_c_block) {\n"
" return;\n"
" }\n"
"#endif\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,_tempoutput1);\n"
"#endif\n"
"}\n"
// conv_2d_1x1_c4h1w1: 1x1 convolution over NC4HW4 buffers.
// Each work-item produces one 4-channel output block (out_c_idx) at one output pixel.
//   global dim0 = out_c_idx * out_w + out_w_idx
//   global dim1 = out_b_idx * out_h + out_h_idx
// Input/output element (cb, b, h, w, lane) lives at
//   (((cb*out_b + b)*out_h + h)*out_w + w)*4 + lane.
// The input is indexed with the output's out_b/out_h/out_w, so the kernel
// assumes the input and output spatial sizes are equal.
// NOTE(review): dim0 is split by out_w rather than out_w_blocks, and neither
// out_w_blocks nor out_c_block is read here. They are presumably kept so all
// 1x1 variants share one argument list; confirm against the host code.
"__kernel\n"
"void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" __global FLOAT *output,\n"
" __private const int in_c_block,\n"
" __private const int out_h,\n"
" __private const int out_w,\n"
" __private const int out_b,\n"
" __private const int out_c_block,\n"
" __private const int out_c_pack) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w;\n"
" const int out_w_idx=out_c_w_idx % out_w;\n"
" const int out_b_idx=out_b_h_idx/out_h;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n"
// The accumulator starts from this block's 4 bias values.
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias_ptr));\n"
" const int intput_width_idx0=out_w_idx;\n"
// offset: this block's 4 columns within a weight row. Each input channel
// block consumes 4 rows of out_c_pack entries, hence the 4*out_c_pack stride below.
" int offset=out_c_idx*4;\n"
" int inp_offset=((out_b_idx*out_h+out_h_idx)*out_w+intput_width_idx0)*4;\n"
// inp_add is the distance between consecutive input channel blocks.
" const int inp_add=out_b*out_h*out_w*4;\n"
" \n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" \n"
" \n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
// weightsK holds the 4 output channels' weights for input lane K.
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights3=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights1,out0);\n"
" out0=mad(in0.z,weights2,out0);\n"
" out0=mad(in0.w,weights3,out0);\n"
" \n"
" offset += 4*out_c_pack;\n"
" inp_offset += inp_add;\n"
" }\n"
// Optional fused activation, selected by the RELU / RELU6 macros.
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*out_b)*out_h+out_h_idx)* out_w+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
"}\n"
// conv_2d_1x1_c4h1w2: 1x1 convolution over NC4HW4 buffers.
// Each work-item produces one 4-channel output block (out_c_idx) for 2
// adjacent output columns of a single (batch, row) pair.
//   global dim0 = out_c_idx * out_w_blocks + w_block   (out_w2_idx = 2 * w_block)
//   global dim1 = out_b_idx * out_h + out_h_idx
// Input/output element (cb, b, h, w, lane) lives at
//   (((cb*out_b + b)*out_h + h)*out_w + w)*4 + lane.
// The input is indexed with the output's sizes, so equal spatial sizes are
// assumed. out_c_block is not read by this kernel.
// Fix: under BLOCK_LEAVE the second column may equal out_w. Loading it anyway
// reads 4 elements past the end of the tensor data for the final (batch, row)
// of the last input channel block. That load only fed out1, which is never
// stored in that case, so it is now skipped and replaced by zero.
"__kernel\n"
"void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" __global FLOAT *output,\n"
" __private const int in_c_block,\n"
" __private const int out_h,\n"
" __private const int out_w,\n"
" __private const int out_b,\n"
" __private const int out_c_block,\n"
" __private const int out_c_pack) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n"
" const int out_w2_idx=mul24(out_w_idx,2);\n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias_ptr));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" const int intput_width_idx0=out_w2_idx;\n"
" int offset=out_c_idx*4;\n"
" int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)*4;\n"
" const int inp_add=out_b*out_h*out_w*4;\n"
// Under BLOCK_LEAVE, the second column exists only when it is still inside the row.
"#ifdef BLOCK_LEAVE\n"
" const int has_w1=(out_w2_idx+1)<out_w;\n"
"#endif\n"
" \n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" \n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
// Skip the out-of-row column instead of reading past the tensor; the
// scalar-condition ?: evaluates only the selected operand.
"#ifdef BLOCK_LEAVE\n"
" COMPUTE_FLOAT4 in1=has_w1 ? CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset)) : (COMPUTE_FLOAT4)0;\n"
"#else\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n"
"#endif\n"
// weightsK holds the 4 output channels' weights for input lane K.
" COMPUTE_FLOAT4 weights0=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset));\n"
" COMPUTE_FLOAT4 weights1=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack));\n"
" COMPUTE_FLOAT4 weights2=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack));\n"
" COMPUTE_FLOAT4 weights3=CONVERT_COMPUTE_FLOAT4(vload4(0,kernel_ptr+offset+out_c_pack+out_c_pack+out_c_pack));\n"
" out0=mad(in0.x,weights0,out0);\n"
" out0=mad(in0.y,weights1,out0);\n"
" out0=mad(in0.z,weights2,out0);\n"
" out0=mad(in0.w,weights3,out0);\n"
" \n"
" out1=mad(in1.x,weights0,out1);\n"
" out1=mad(in1.y,weights1,out1);\n"
" out1=mad(in1.z,weights2,out1);\n"
" out1=mad(in1.w,weights3,out1);\n"
" \n"
" offset += 4*out_c_pack;\n"
" inp_offset += inp_add;\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*out_b)*out_h+out_h_idx)* out_w+out_w2_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_w-out_w2_idx;\n"
" if (remain >= 2) {\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" } else if (remain == 1) {\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
"#endif\n"
"}\n"
// conv_2d_c4h1w1: general 2D convolution over NC4HW4 buffers.
// Filter size, stride, padding and dilation come from filter_hw, stride_hw,
// pad_hw and dilate_hw, each given as (h, w) in (.x, .y).
// Each work-item produces one 4-channel output block at one output pixel.
//   global dim0 = out_c_idx * out_hw.y + out_w_idx
//   global dim1 = out_b_idx * out_hw.x + out_h_idx
// Taps are not bounds-checked one by one:
//   - kh_start/kw_start skip the leading filter taps that land at negative
//     input coordinates;
//   - in_*_idx_end clamps the trailing taps to the input size;
// so every load inside the loops is in range.
// NOTE(review): in_c_idx is a ushort, so in_c_blocks must stay below 65536.
// inChannel, out_w_blocks and out_h_blocks are not read by this kernel.
"__kernel\n"
"void conv_2d_c4h1w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_hw.y;\n"
" const int out_w_idx=out_c_w_idx % out_hw.y;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" \n"
// Input coordinate of filter tap (0, 0); may be negative because of padding.
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
// When base < 0, ceil(-base / dilation) is the first tap index that lands at
// a non-negative input coordinate; otherwise start at tap 0.
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
// Stride between the weight rows of consecutive input-channel lanes
// (4*in_c_idx + k). Despite its name, it steps over input channels, not
// output channel blocks.
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+kw_start)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" for(int ix=in_w_idx_start; ix<in_w_idx_end; ix += dilate_hw.y) {\n"
" int inp_offset=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+ix)*4;\n"
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" \n"
// Column tap index relative to kw_start; weight_offset already points at kw_start.
" const int filter_w_inc=(ix-in_w_idx_start)/dilate_hw.y;\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(filter_w_inc,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(filter_w_inc,weight+weight_offset+weight_oc_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(filter_w_inc,weight+weight_offset+weight_oc_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(filter_w_inc,weight+weight_offset+weight_oc_offset*3));\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" }\n"
// Advance to the same column (kw_start) of the next filter row.
" weight_offset += 4*filter_hw.y;\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" \n"
"}\n"
// conv_2d_c4h1w2: general 2D convolution over NC4HW4 buffers.
// Each work-item produces one 4-channel output block for 2 adjacent output columns.
//   global dim0 = out_c_idx * out_w_blocks + w_block   (out_w_idx = 2 * w_block)
//   global dim1 = out_b_idx * out_hw.x + out_h_idx
// Boundary handling:
//   - rows are clipped up front with kh_start / in_h_idx_end;
//   - columns are checked per tap, and out-of-range input columns contribute zero.
// inChannel and out_h_blocks are not read by this kernel.
"__kernel\n"
"void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,//generate width's num\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=(out_c_w_idx % out_w_blocks) << 1;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" \n"
// Input column of tap 0 for each of the 2 output columns (may be negative).
" const int in_w0_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_w1_idx_base=in_w0_idx_base+stride_hw.y;\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
// First filter row landing inside the input (ceil division when base < 0).
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
// Stride between the weight rows of consecutive input-channel lanes (4*in_c_idx + k).
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
// Starts at tap (kh_start, 0). It advances by 4 per column tap, and every column
// of a row is visited, so after a row it already points at the next row.
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" const int inp_offset_base=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+0)*4;\n"
" for(int fw=0; fw<filter_hw.y; fw++) {\n"
" const int in_w0_idx=fw*dilate_hw.y+in_w0_idx_base;\n"
" const int in_w1_idx=fw*dilate_hw.y+in_w1_idx_base;\n"
// Out-of-range columns read as zero; the load is only evaluated when in range.
" COMPUTE_FLOAT4 in0=(in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=(in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w1_idx,input+inp_offset_base));\n"
" \n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset*3));\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
// With BLOCK_LEAVE the second column is stored only if it exists. Without it,
// both are always stored; presumably the host defines BLOCK_LEAVE whenever
// out_hw.y is odd (confirm).
"#ifdef BLOCK_LEAVE\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" if(out_w_idx+1 >= out_hw.y) return;\n"
" vstore4(CONVERT_FLOAT4(out1),1,output+out_offset);\n"
"#else\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
"#endif\n"
"}\n"
// conv_2d_c4h1w4: general 2D convolution over NC4HW4 buffers.
// Each work-item produces one 4-channel output block for 4 adjacent output columns.
//   global dim0 = out_c_idx * out_w_blocks + w_block   (out_w_idx = 4 * w_block)
//   global dim1 = out_b_idx * out_hw.x + out_h_idx
// Boundary handling:
//   - rows are clipped up front with kh_start / in_h_idx_end;
//   - columns are checked per tap, and out-of-range input columns contribute zero.
// inChannel and out_h_blocks are not read by this kernel.
"__kernel\n"
"void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=(out_c_w_idx % out_w_blocks) << 2;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=out0;\n"
" COMPUTE_FLOAT4 out3=out0;\n"
// Input column of tap 0 for each of the 4 output columns (may be negative).
" const int in_w0_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_w1_idx_base=in_w0_idx_base+stride_hw.y;\n"
" const int in_w2_idx_base=in_w1_idx_base+stride_hw.y;\n"
" const int in_w3_idx_base=in_w2_idx_base+stride_hw.y;\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
// First filter row landing inside the input (ceil division when base < 0).
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
// Stride between the weight rows of consecutive input-channel lanes (4*in_c_idx + k).
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
// Starts at tap (kh_start, 0) and advances by 4 per column tap across all rows.
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" const int inp_offset_base=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+0)*4;\n"
" for(int fw=0; fw<filter_hw.y; fw++) {\n"
" const int in_w0_idx=fw*dilate_hw.y+in_w0_idx_base;\n"
" const int in_w1_idx=fw*dilate_hw.y+in_w1_idx_base;\n"
" const int in_w2_idx=fw*dilate_hw.y+in_w2_idx_base;\n"
" const int in_w3_idx=fw*dilate_hw.y+in_w3_idx_base;\n"
// Out-of-range columns read as zero; each load is only evaluated when in range.
" COMPUTE_FLOAT4 in0=(in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=(in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w1_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=(in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w2_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=(in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w3_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset*3));\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
// With BLOCK_LEAVE only the columns that exist (remain of them, up to 4) are stored.
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.y-out_w_idx;\n"
" if (remain >= 4) {\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
"#endif\n"
"}\n"
// conv_2d_c4h4w1: general 2D convolution over NC4HW4 buffers.
// Each work-item produces one 4-channel output block for 4 consecutive output
// rows at a single output column.
//   global dim0 = out_c_idx * out_w_blocks + out_w_idx
//   global dim1 = out_b_idx * out_h_blocks + h_block   (out_h_idx = 4 * h_block)
// Boundary handling:
//   - columns are clipped up front with kw_start / in_w_idx_end;
//   - every filter row is visited, and input rows outside the image contribute zero.
// inChannel is not read by this kernel.
"__kernel\n"
"void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n"
" const int out_h_idx=(out_b_h_idx % out_h_blocks) << 2;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=out0;\n"
" COMPUTE_FLOAT4 out3=out0;\n"
// Input row of filter row 0 for each of the 4 output rows (may be negative).
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n"
" const int in_h2_idx_base=in_h1_idx_base+stride_hw.x;\n"
" const int in_h3_idx_base=in_h2_idx_base+stride_hw.x;\n"
" \n"
// First filter column landing inside the input (ceil division when base < 0).
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
// Stride between the weight rows of consecutive input-channel lanes (4*in_c_idx + k).
" const int weight_oc_offset=out_c_blocks*filter_hw.x*filter_hw.y*4;\n"
" const int in_hw_size=in_hw.x*in_hw.y;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n"
" for(int iy=0; iy<filter_hw.x; iy++) {\n"
// Weights for filter row iy, starting at column kw_start; +4 per column tap below.
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+iy)*filter_hw.y+kw_start)*4;\n"
// in_hK_idx is the pixel index of the start of the input row (row * in_hw.y).
// Checking it against [0, in_hw_size) is equivalent to checking the row
// against [0, in_hw.x), given in_hw.y > 0.
" const int in_h0_idx=(iy*dilate_hw.x+in_h0_idx_base)*in_hw.y;\n"
" const int in_h1_idx=(iy*dilate_hw.x+in_h1_idx_base)*in_hw.y;\n"
" const int in_h2_idx=(iy*dilate_hw.x+in_h2_idx_base)*in_hw.y;\n"
" const int in_h3_idx=(iy*dilate_hw.x+in_h3_idx_base)*in_hw.y;\n"
// Despite its name, fw is an input column index here, not a filter index.
" for(int fw=in_w_idx_start; fw<in_w_idx_end; fw += dilate_hw.y) {\n"
" COMPUTE_FLOAT4 in0=(in_h0_idx<0 || in_h0_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h0_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=(in_h1_idx<0 || in_h1_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h1_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=(in_h2_idx<0 || in_h2_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h2_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=(in_h3_idx<0 || in_h3_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h3_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset*3));\n"
" \n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
// Consecutive output rows are out_hw.y pixels (vec4s) apart.
// With BLOCK_LEAVE only the rows that exist are stored.
" const int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.x-out_h_idx;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n"
" const int out_h_idx=(out_b_h_idx % out_h_blocks) << 2;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=out0;\n"
" COMPUTE_FLOAT4 out3=out0;\n"
" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n"
" COMPUTE_FLOAT4 out5=out4;\n"
" COMPUTE_FLOAT4 out6=out4;\n"
" COMPUTE_FLOAT4 out7=out4;\n"
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n"
" const int in_h2_idx_base=in_h1_idx_base+stride_hw.x;\n"
" const int in_h3_idx_base=in_h2_idx_base+stride_hw.x;\n"
" \n"
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
" const int weight_oc_offset=filter_hw.x*filter_hw.y*4;\n"
" const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n"
" const int in_hw_size=in_hw.x*in_hw.y;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n"
" for(int iy=0; iy<filter_hw.x; iy++) {\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+iy)*filter_hw.y+kw_start)*4;\n"
" const int in_h0_idx=(iy*dilate_hw.x+in_h0_idx_base)*in_hw.y;\n"
" const int in_h1_idx=(iy*dilate_hw.x+in_h1_idx_base)*in_hw.y;\n"
" const int in_h2_idx=(iy*dilate_hw.x+in_h2_idx_base)*in_hw.y;\n"
" const int in_h3_idx=(iy*dilate_hw.x+in_h3_idx_base)*in_hw.y;\n"
" for(int fw=in_w_idx_start; fw<in_w_idx_end; fw += dilate_hw.y) {\n"
" COMPUTE_FLOAT4 in0=(in_h0_idx<0 || in_h0_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h0_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=(in_h1_idx<0 || in_h1_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h1_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=(in_h2_idx<0 || in_h2_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h2_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=(in_h3_idx<0 || in_h3_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h3_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset*3));\n"
" \n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n"
" weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n"
" weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n"
" weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n"
" out4=mad(in0.x,weight0,out4);\n"
" out4=mad(in0.y,weight1,out4);\n"
" out4=mad(in0.z,weight2,out4);\n"
" out4=mad(in0.w,weight3,out4);\n"
" \n"
" out5=mad(in1.x,weight0,out5);\n"
" out5=mad(in1.y,weight1,out5);\n"
" out5=mad(in1.z,weight2,out5);\n"
" out5=mad(in1.w,weight3,out5);\n"
" \n"
" out6=mad(in2.x,weight0,out6);\n"
" out6=mad(in2.y,weight1,out6);\n"
" out6=mad(in2.z,weight2,out6);\n"
" out6=mad(in2.w,weight3,out6);\n"
" \n"
" out7=mad(in3.x,weight0,out7);\n"
" out7=mad(in3.y,weight1,out7);\n"
" out7=mad(in3.z,weight2,out7);\n"
" out7=mad(in3.w,weight3,out7);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
" out4=fmax(out4,(COMPUTE_FLOAT4)0);\n"
" out5=fmax(out5,(COMPUTE_FLOAT4)0);\n"
" out6=fmax(out6,(COMPUTE_FLOAT4)0);\n"
" out7=fmax(out7,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out4=clamp(out4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out5=clamp(out5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.x-out_h_idx;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
" #ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
" #endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" if(remain >= 4){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out7),3*out_hw.y,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n"
" #ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
" #endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out7),3*out_hw.y,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n"
" const int out_h_idx=(out_b_h_idx % out_h_blocks) << 1;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n"
" COMPUTE_FLOAT4 out3=out2;\n"
" const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n"
" \n"
" const int kw_start=select(0,(-in_w_idx_base+dilate_hw.y-1)/dilate_hw.y,in_w_idx_base<0);\n"
" const int in_w_idx_start=mad24(kw_start,dilate_hw.y,in_w_idx_base);\n"
" const int in_w_idx_end=min(mad24(filter_hw.y,dilate_hw.y,in_w_idx_base),in_hw.y);\n"
" \n"
" const int weight_oc_offset=filter_hw.x*filter_hw.y*4;\n"
" const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n"
" const int in_hw_size=in_hw.x*in_hw.y;\n"
" // weight: [ic/4,oc,4],loop: ic/4\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n"
" for(int iy=0; iy<filter_hw.x; iy++) {\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+iy)*filter_hw.y+kw_start)*4;\n"
" const int in_h0_idx=(iy*dilate_hw.x+in_h0_idx_base)*in_hw.y;\n"
" const int in_h1_idx=(iy*dilate_hw.x+in_h1_idx_base)*in_hw.y;\n"
" for(int fw=in_w_idx_start; fw<in_w_idx_end; fw += dilate_hw.y) {\n"
" COMPUTE_FLOAT4 in0=(in_h0_idx<0 || in_h0_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h0_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=(in_h1_idx<0 || in_h1_idx >= in_hw_size) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_h1_idx+fw,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset*3));\n"
" \n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n"
" weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n"
" weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n"
" weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n"
" \n"
" out2=mad(in0.x,weight0,out2);\n"
" out2=mad(in0.y,weight1,out2);\n"
" out2=mad(in0.z,weight2,out2);\n"
" out2=mad(in0.w,weight3,out2);\n"
" \n"
" out3=mad(in1.x,weight0,out3);\n"
" out3=mad(in1.y,weight1,out3);\n"
" out3=mad(in1.z,weight2,out3);\n"
" out3=mad(in1.w,weight3,out3);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.x-out_h_idx;\n"
" if(remain >= 2){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
" #ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
" #endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" if(remain >= 2){\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n"
" #ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks){\n"
" return;\n"
" }\n"
" #endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n"
"#endif\n"
"}\n"
"__kernel\n"
"void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
" __global const FLOAT *weight,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT *output,\n"
" __private const int2 in_hw,\n"
" __private const int inChannel,\n"
" __private const int in_c_blocks,\n"
" __private const int batch,\n"
" __private const int2 out_hw,\n"
" __private const int2 filter_hw,\n"
" __private const int2 stride_hw,\n"
" __private const int2 pad_hw,\n"
" __private const int2 dilate_hw,\n"
" __private const int out_w_blocks,\n"
" __private const int out_c_blocks,\n"
" __private const int out_h_blocks) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n"
" const int out_w_idx=(out_c_w_idx % out_w_blocks) << 2;\n"
" const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n"
" const int out_h_idx=out_b_h_idx % out_hw.x;\n"
" \n"
" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n"
" COMPUTE_FLOAT4 out1=out0;\n"
" COMPUTE_FLOAT4 out2=out0;\n"
" COMPUTE_FLOAT4 out3=out0;\n"
" \n"
" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n"
" COMPUTE_FLOAT4 out5=out4;\n"
" COMPUTE_FLOAT4 out6=out4;\n"
" COMPUTE_FLOAT4 out7=out4;\n"
" const int in_w0_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n"
" const int in_w1_idx_base=in_w0_idx_base+stride_hw.y;\n"
" const int in_w2_idx_base=in_w1_idx_base+stride_hw.y;\n"
" const int in_w3_idx_base=in_w2_idx_base+stride_hw.y;\n"
" const int in_h_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n"
" \n"
" const int kh_start=select(0,(-in_h_idx_base+dilate_hw.x-1)/dilate_hw.x,in_h_idx_base<0);\n"
" const int in_h_idx_start=mad24(kh_start,dilate_hw.x,in_h_idx_base);\n"
" const int in_h_idx_end=min(mad24(filter_hw.x,dilate_hw.x,in_h_idx_base),in_hw.x);\n"
" \n"
" const int weight_oc_offset=filter_hw.x*filter_hw.y*4;\n"
" const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n"
" for(ushort in_c_idx=0; in_c_idx<in_c_blocks; in_c_idx++) {\n"
" //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n"
" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n"
" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n"
" for(int iy=in_h_idx_start; iy<in_h_idx_end; iy += dilate_hw.x) {\n"
" const int inp_offset_base=(((out_b_idx+in_c_idx*batch)*in_hw.x+iy)*in_hw.y+0)*4;\n"
" for(int fw=0; fw<filter_hw.y; fw++) {\n"
" const int in_w0_idx=fw*dilate_hw.y+in_w0_idx_base;\n"
" const int in_w1_idx=fw*dilate_hw.y+in_w1_idx_base;\n"
" const int in_w2_idx=fw*dilate_hw.y+in_w2_idx_base;\n"
" const int in_w3_idx=fw*dilate_hw.y+in_w3_idx_base;\n"
" COMPUTE_FLOAT4 in0=(in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=(in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w1_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=(in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w2_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=(in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(in_w3_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset));\n"
" COMPUTE_FLOAT4 weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset));\n"
" COMPUTE_FLOAT4 weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset*2));\n"
" COMPUTE_FLOAT4 weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_ic_offset*3));\n"
" out0=mad(in0.x,weight0,out0);\n"
" out0=mad(in0.y,weight1,out0);\n"
" out0=mad(in0.z,weight2,out0);\n"
" out0=mad(in0.w,weight3,out0);\n"
" \n"
" out1=mad(in1.x,weight0,out1);\n"
" out1=mad(in1.y,weight1,out1);\n"
" out1=mad(in1.z,weight2,out1);\n"
" out1=mad(in1.w,weight3,out1);\n"
" \n"
" out2=mad(in2.x,weight0,out2);\n"
" out2=mad(in2.y,weight1,out2);\n"
" out2=mad(in2.z,weight2,out2);\n"
" out2=mad(in2.w,weight3,out2);\n"
" \n"
" out3=mad(in3.x,weight0,out3);\n"
" out3=mad(in3.y,weight1,out3);\n"
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
" weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n"
" weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n"
" weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n"
" weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n"
" \n"
" out4=mad(in0.x,weight0,out4);\n"
" out4=mad(in0.y,weight1,out4);\n"
" out4=mad(in0.z,weight2,out4);\n"
" out4=mad(in0.w,weight3,out4);\n"
" \n"
" out5=mad(in1.x,weight0,out5);\n"
" out5=mad(in1.y,weight1,out5);\n"
" out5=mad(in1.z,weight2,out5);\n"
" out5=mad(in1.w,weight3,out5);\n"
" \n"
" out6=mad(in2.x,weight0,out6);\n"
" out6=mad(in2.y,weight1,out6);\n"
" out6=mad(in2.z,weight2,out6);\n"
" out6=mad(in2.w,weight3,out6);\n"
" \n"
" out7=mad(in3.x,weight0,out7);\n"
" out7=mad(in3.y,weight1,out7);\n"
" out7=mad(in3.z,weight2,out7);\n"
" out7=mad(in3.w,weight3,out7);\n"
" \n"
" weight_offset += 4;\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
" out4=fmax(out4,(COMPUTE_FLOAT4)0);\n"
" out5=fmax(out5,(COMPUTE_FLOAT4)0);\n"
" out6=fmax(out6,(COMPUTE_FLOAT4)0);\n"
" out7=fmax(out7,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out4=clamp(out4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out5=clamp(out5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
"#ifdef BLOCK_LEAVE\n"
" const int remain=out_hw.y-out_w_idx;\n"
" if(remain >= 4){\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out2),2,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n"
" }\n"
" #ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks)return;\n"
" #endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" if(remain >= 4){\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n"
" }else if(remain == 3){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out6),2,output+out_offset);\n"
" }else if(remain == 2){\n"
" vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out4,out5)),0,output+out_offset);\n"
" }else if(remain == 1){\n"
" vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n"
" }\n"
"#else\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n"
" #ifdef CHANNEL_LEAVE\n"
" if(out_c_idx+1 >= out_c_blocks)return;\n"
" #endif\n"
" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n"
" vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n"
"#endif\n"
"}\n"
;
#endif
// buffer_to_image: OpenCL C source for the buffer <-> image2d_t conversion
// kernels, stored as one concatenated string literal.
// Preamble:
//  - enables cl_khr_fp16 when MNN_SUPPORT_FP16 is defined;
//  - GLOBAL_SIZE_2_DIMS adds the two global-size parameters to a kernel signature;
//  - DEAL_NON_UNIFORM_DIM2 returns early for work-items at or beyond those sizes
//    (presumably because the launched grid is rounded up);
//  - SAMPLER reads images with unnormalized coordinates, clamp addressing and
//    nearest filtering.
// NOTE(review): FLOAT/FLOAT4, WI_F/RI_F, INPUT_TYPE*/OUTPUT_TYPE*, WI_DATA/RI_DATA
// and the CONVERT_* macros are not defined in this preamble; presumably they are
// supplied as program build options by the host — confirm.
const char* buffer_to_image = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"// convert kernel : from buffer(oi ) to image(oc,ic/4)\n"
"__kernel void conv2d1x1_opt_filter_buffer_to_image(GLOBAL_SIZE_2_DIMS __global const FLOAT *input_ptr,\n"
" __private const int input_channel,__private const int2 kernel_shape,__private const int ic_h_w_size,\n"
" __private const int height_width_size,__write_only image2d_t output) {\n"
" \n"
" int ic_4_idx=get_global_id(0); // ic/4\n"
" int oc_idx=get_global_id(1); // oc\n"
" DEAL_NON_UNIFORM_DIM2(ic_4_idx,oc_idx);\n"
" const int ic_idx=ic_4_idx*4;\n"
" const int buffer_offset=oc_idx*input_channel+ic_idx;\n"
" \n"
" FLOAT4 output_values=0;\n"
" if (ic_idx<input_channel) {\n"
" const int remain_channel=input_channel-ic_idx;\n"
" if (remain_channel >= 4) {\n"
" output_values.x=*(input_ptr+buffer_offset);\n"
" output_values.y=*(input_ptr+buffer_offset+1);\n"
" output_values.z=*(input_ptr+buffer_offset+2);\n"
" output_values.w=*(input_ptr+buffer_offset+3);\n"
" } else if (remain_channel == 3) {\n"
" output_values.x=*(input_ptr+buffer_offset);\n"
" output_values.y=*(input_ptr+buffer_offset+1);\n"
" output_values.z=*(input_ptr+buffer_offset+2);\n"
" output_values.w=0;\n"
" } else if (remain_channel == 2) {\n"
" output_values.x=*(input_ptr+buffer_offset);\n"
" output_values.y=*(input_ptr+buffer_offset+1);\n"
" output_values.z=0;\n"
" output_values.w=0;\n"
" } else if (remain_channel == 1) {\n"
" output_values.x=*(input_ptr+buffer_offset);\n"
" output_values.y=0;\n"
" output_values.z=0;\n"
" output_values.w=0;\n"
" }\n"
" }\n"
" WI_F(output,(int2)(ic_4_idx,oc_idx),output_values);\n"
"}\n"
"// convert kernel : from buffer(oihw) to image(oc/4 h w ,ic oc4)\n"
"__kernel void conv2d_filter_buffer_to_image(GLOBAL_SIZE_2_DIMS\n"
" #ifdef BUFFER_INP_FP32\n"
" __global const float *input_ptr,\n"
" #else\n"
" __global const FLOAT *input_ptr,\n"
" #endif\n"
" __private const int output_channel,__private const int2 kernel_shape,__private const int ic_h_w_size,\n"
" __private const int height_width_size,__write_only image2d_t output) {\n"
" int image_width_idx=get_global_id(0); // ic\n"
" int image_height_idx=get_global_id(1); // oc/4 h w\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int input_channel_4_idx=image_width_idx;\n"
" const int output_channel_4_idx=(image_height_idx/height_width_size)*4;\n"
" const int height_width_idx=image_height_idx % height_width_size;\n"
" const int buffer_height_idx=height_width_idx/kernel_shape.y;\n"
" const int buffer_width_idx=height_width_idx % kernel_shape.y;\n"
" const int buffer_offset=output_channel_4_idx*ic_h_w_size+input_channel_4_idx*height_width_size +\n"
" buffer_height_idx*kernel_shape.y+buffer_width_idx;\n"
" FLOAT4 output_values=0;\n"
" if (output_channel_4_idx<output_channel) {\n"
" const int remain_channel=output_channel-output_channel_4_idx;\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.w=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" WI_F(output,(int2)(image_width_idx,image_height_idx),output_values);\n"
"}\n"
"// only for debug\n"
"// convert kernel : from image(oc/4 h w ,ic oc4) to buffer(oihw)\n"
"__kernel void conv2d_filter_image_to_buffer(GLOBAL_SIZE_2_DIMS __global FLOAT *output_ptr,\n"
" __private const int output_channel,__private const int2 kernel_shape,\n"
" __private const int ic_h_w_size,\n"
" __private const int height_width_size,__read_only image2d_t input_ptr) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int input_channel_4_idx=image_width_idx;\n"
" const int output_channel_4_idx=image_height_idx/height_width_size*4;\n"
" const int height_width_idx=image_height_idx % height_width_size;\n"
" const int buffer_height_idx=height_width_idx/kernel_shape.y;\n"
" const int buffer_width_idx=height_width_idx % kernel_shape.y;\n"
" const int buffer_offset=output_channel_4_idx*ic_h_w_size+input_channel_4_idx*height_width_size +\n"
" buffer_height_idx*kernel_shape.y+buffer_width_idx;\n"
" if (output_channel_4_idx<output_channel) {\n"
" int2 coord=(int2)(image_width_idx,image_height_idx);\n"
" FLOAT4 values=RI_F(input_ptr,SAMPLER,coord);\n"
" const int remain_channel=(output_channel-output_channel_4_idx);\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_ptr[offset]=values.x;\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_ptr[offset]=values.y;\n"
" offset += ic_h_w_size;\n"
" output_ptr[offset]=values.z;\n"
" offset += ic_h_w_size;\n"
" output_ptr[offset]=values.w;\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_ptr[offset]=values.x;\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_ptr[offset]=values.y;\n"
" offset += ic_h_w_size;\n"
" output_ptr[offset]=values.z;\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_ptr[offset]=values.x;\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_ptr[offset]=values.y;\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_ptr[offset]=values.x;\n"
" }\n"
" }\n"
"}\n"
"// convert kernel from buffer(mihw) to image(ic/4,ic4 h w m)\n"
"// but now dw only support m == 1\n"
"__kernel void dw_filter_buffer_to_image(GLOBAL_SIZE_2_DIMS\n"
" #ifdef BUFFER_INP_FP32\n"
" __global const float *input_ptr,\n"
" #else\n"
" __global const FLOAT *input_ptr,\n"
" #endif\n"
" __private const int4 kernel_shape,\n"
" __private const int height_width_size,__write_only image2d_t output) {\n"
" const int image_width_idx=get_global_id(0);\n"
" const int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" FLOAT4 output_values=0;\n"
" if (kernel_shape.x == 1) {\n"
" const int input_channel_4_idx=image_height_idx*4;\n"
" const int buffer_height_idx=image_width_idx/kernel_shape.w;\n"
" const int buffer_width_idx=image_width_idx % kernel_shape.w;\n"
" const int buffer_offset =\n"
" mad24(mad24(input_channel_4_idx,kernel_shape.z,buffer_height_idx),kernel_shape.w,buffer_width_idx);\n"
" const int remain_channel=kernel_shape.y-input_channel_4_idx;\n"
" if (input_channel_4_idx<kernel_shape.y) {\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.w=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" }\n"
" WI_F(output,(int2)(image_width_idx,image_height_idx),output_values);\n"
"}\n"
"__kernel void nc4hw4_buffer_to_image(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int2 output_shape,\n"
" __private const int batch_size,__write_only image2d_t output) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/output_shape.x;\n"
" const int height_idx=image_height_idx % output_shape.x;\n"
" const int width_idx=image_width_idx % output_shape.y;\n"
" const int channel_block_idx=image_width_idx/output_shape.y;\n"
" int buffer_offset =\n"
" (((batch_idx+channel_block_idx*batch_size)*output_shape.x+height_idx)*output_shape.y+width_idx)*4;\n"
" int2 coord=(int2)(image_width_idx,image_height_idx);\n"
" WI_DATA(output,coord,CONVERT_OUTPUT_I4(vload4(0,input_ptr+buffer_offset)));\n"
"}\n"
"__kernel void image_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int2 output_shape,\n"
" __private const int batch_size,\n"
" __read_only image2d_t input_ptr) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/output_shape.x;\n"
" const int height_idx=image_height_idx % output_shape.x;\n"
" const int width_idx=image_width_idx % output_shape.y;\n"
" int channel_block_idx=image_width_idx/output_shape.y;\n"
" int buffer_offset =\n"
" (((batch_idx+channel_block_idx*batch_size)*output_shape.x+height_idx)*output_shape.y+width_idx)*4;\n"
" int2 coord=(int2)(image_width_idx,image_height_idx);\n"
" vstore4(CONVERT_OUTPUT4(RI_DATA(input_ptr,SAMPLER,coord)),0,output+buffer_offset);\n"
"}\n"
"__kernel void nhwc_buffer_to_image(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int height,\n"
" __private const int width,__private const int channels,\n"
" __write_only image2d_t output) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" const int channel_4_idx=(image_width_idx/width) << 2;\n"
" const int buffer_offset=((batch_idx*height+height_idx)*width+width_idx)*channels+channel_4_idx;\n"
" const int remain_channel=channels-channel_4_idx;\n"
" INPUT_TYPE4 values=vload4(0,input_ptr+buffer_offset);\n"
" if (remain_channel == 3) {\n"
" values.w=0;\n"
" } else if (remain_channel == 2) {\n"
" values.z=0;\n"
" values.w=0;\n"
" } else if (remain_channel == 1) {\n"
" values.y=0;\n"
" values.z=0;\n"
" values.w=0;\n"
" }\n"
" WI_DATA(output,(int2)(image_width_idx,image_height_idx),CONVERT_OUTPUT_I4(values));\n"
"}\n"
"__kernel void nchw_buffer_to_image(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int height,__private const int width,__private const int channels,\n"
" __write_only image2d_t output) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" const int channel_4_idx=image_width_idx/width << 2;\n"
" const int buffer_offset=((batch_idx*channels+channel_4_idx)*height+height_idx)*width+width_idx;\n"
" const int remain_channel=channels-channel_4_idx;\n"
" const int height_width_size=height*width;\n"
" INPUT_TYPE4 output_values=0;\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=*(input_ptr+offset);\n"
" offset += height_width_size;\n"
" output_values.y=*(input_ptr+offset);\n"
" offset += height_width_size;\n"
" output_values.z=*(input_ptr+offset);\n"
" offset += height_width_size;\n"
" output_values.w=*(input_ptr+offset);\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=*(input_ptr+offset);\n"
" offset += height_width_size;\n"
" output_values.y=*(input_ptr+offset);\n"
" offset += height_width_size;\n"
" output_values.z=*(input_ptr+offset);\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=*(input_ptr+offset);\n"
" offset += height_width_size;\n"
" output_values.y=*(input_ptr+offset);\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=*(input_ptr+offset);\n"
" }\n"
" WI_DATA(output,(int2)(image_width_idx,image_height_idx),CONVERT_OUTPUT_I4(output_values));\n"
"}\n"
"__kernel void image_to_nhwc_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int height,__private const int width,\n"
" __private const int channels,\n"
" __read_only image2d_t input_ptr) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" const int channel_4_idx=(image_width_idx/width) << 2;\n"
" const int buffer_offset=((batch_idx*height+height_idx)*width+width_idx)*channels+channel_4_idx;\n"
" int2 coord=(int2)(image_width_idx,image_height_idx);\n"
" \n"
" INPUT_TYPE_I4 values=RI_DATA(input_ptr,SAMPLER,coord);\n"
" const int remain_channel=channels-channel_4_idx;\n"
" if (remain_channel >= 4) {\n"
" vstore4(CONVERT_OUTPUT4(values),0,output+buffer_offset);\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset++;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" offset++;\n"
" output[offset]=(OUTPUT_TYPE)values.z;\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset++;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" }\n"
"}\n"
"__kernel void image_to_nchw_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int height,__private const int width,\n"
" __private const int channels,\n"
" __read_only image2d_t input_ptr) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int batch_idx=image_height_idx/height;\n"
" const int height_idx=image_height_idx % height;\n"
" const int width_idx=image_width_idx % width;\n"
" int channel_4_idx=(image_width_idx/width)*4;\n"
" int buffer_offset=((batch_idx*channels+channel_4_idx)*height+height_idx)*width+width_idx;\n"
" \n"
" INPUT_TYPE_I4 values=RI_DATA(input_ptr,SAMPLER,(int2)(image_width_idx,image_height_idx));\n"
" const int height_width_size=height*width;\n"
" const int remain_channel=channels-channel_4_idx;\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset += height_width_size;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" offset += height_width_size;\n"
" output[offset]=(OUTPUT_TYPE)values.z;\n"
" offset += height_width_size;\n"
" output[offset]=(OUTPUT_TYPE)values.w;\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset += height_width_size;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" offset += height_width_size;\n"
" output[offset]=(OUTPUT_TYPE)values.z;\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset += height_width_size;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" }\n"
"}\n"
"// convert arg as 4 alignment\n"
"__kernel void arg_buffer_to_image(GLOBAL_SIZE_2_DIMS __global const INPUT_TYPE *input_ptr,__private const int count,\n"
" __write_only image2d_t output) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int buffer_4_offset=image_width_idx << 2;\n"
" const int remain=count-buffer_4_offset;\n"
" int offset=buffer_4_offset;\n"
" INPUT_TYPE4 values=0;\n"
" if (remain >= 4) {\n"
" values=vload4(0,input_ptr+offset);\n"
" } else if (remain == 3) {\n"
" values.x=*(input_ptr+offset);\n"
" offset++;\n"
" values.y=*(input_ptr+offset);\n"
" offset++;\n"
" values.z=*(input_ptr+offset);\n"
" } else if (remain == 2) {\n"
" values.x=*(input_ptr+offset);\n"
" offset++;\n"
" values.y=*(input_ptr+offset);\n"
" } else if (remain == 1) {\n"
" values.x=*(input_ptr+offset);\n"
" }\n"
" WI_DATA(output,(int2)(image_width_idx,image_height_idx),CONVERT_OUTPUT_I4(values));\n"
"}\n"
"// only for debug\n"
"__kernel void arg_image_to_buffer(GLOBAL_SIZE_2_DIMS __global OUTPUT_TYPE *output,__private const int count,\n"
" __read_only image2d_t input_ptr) {\n"
" int image_width_idx=get_global_id(0);\n"
" int image_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int buffer_4_offset=image_width_idx << 2;\n"
" int2 coord=(int2)(image_width_idx,image_height_idx);\n"
" INPUT_TYPE_I4 values=RI_DATA(input_ptr,SAMPLER,coord);\n"
" const int remain=count-buffer_4_offset;\n"
" if (remain<4) {\n"
" switch (remain) {\n"
" case 3:\n"
" output[buffer_4_offset+2]=(OUTPUT_TYPE)values.s2;\n"
" case 2:\n"
" output[buffer_4_offset+1]=(OUTPUT_TYPE)values.s1;\n"
" case 1:\n"
" output[buffer_4_offset]=(OUTPUT_TYPE)values.s0;\n"
" }\n"
" } else {\n"
" vstore4(CONVERT_OUTPUT4(values),0,output+buffer_4_offset);\n"
" }\n"
" if (remain >= 4) {\n"
" vstore4(CONVERT_OUTPUT4(values),0,output+buffer_4_offset);\n"
" } else if (remain == 3) {\n"
" int offset=buffer_4_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset++;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" offset++;\n"
" output[offset]=(OUTPUT_TYPE)values.z;\n"
" } else if (remain == 2) {\n"
" int offset=buffer_4_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" offset++;\n"
" output[offset]=(OUTPUT_TYPE)values.y;\n"
" } else if (remain == 1) {\n"
" int offset=buffer_4_offset;\n"
" output[offset]=(OUTPUT_TYPE)values.x;\n"
" }\n"
"}\n"
;
// OpenCL C source for kernel `winogradTransformDest` (image2d based).
// Each work-item pos = (unit index, output channel block pos.y):
//   - reads 16 FLOAT4 texels S{a}{b} from uInput at column
//     unitWidth_idx + unitWidth*(a + 4*b), row pos.y*unitHeight + unitHeight_idx;
//   - applies the output transform A^T = [[1,1,1,0],[0,1,-1,1]] along b (m_a0, m_a1)
//     and then along a, producing a 2x2 output tile at
//     (ox, oy) = (2*unitWidth_idx + {0,1}, 2*unitHeight_idx + {0,1});
//   - adds the bias texel uBias(pos.y, 0), optionally applies RELU / RELU6, and writes
//     to uOutput at (ox + pos.y*dstWidth, oy + batchOffset*dstHeight), skipping
//     outputs with ox >= dstWidth or oy >= dstHeight.
// The "2_3" in the name presumably denotes Winograd F(2,3) — confirm against the host.
// FLOAT4 / RI_F / WI_F are not defined in this text; presumably injected as build
// options by the host — confirm.
const char* winogradTransformDest2_3_1 = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void winogradTransformDest(__read_only image2d_t uInput,// 0\n"
" __read_only image2d_t uBias,__write_only image2d_t uOutput,\n"
" __private const int unitWidth,// 3\n"
" __private const int unitHeight,__private const int dstWidth,\n"
" __private const int dstHeight,// 6\n"
" __private const int dstChannelC4,__private const int batchOffset) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" if (pos.x<unitWidth*unitHeight && pos.y<dstChannelC4) {\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int srcY=pos.y*unitHeight+unitHeight_idx;\n"
" FLOAT4 bias=RI_F(uBias,SAMPLER,(int2)(pos.y,0));\n"
" {\n"
" int oyStart=unitHeight_idx*2;\n"
" int oxStart=unitWidth_idx*2;\n"
// Load the 4x4 transformed tile: S{a}{b} lives at column block (a + 4*b).
" FLOAT4 S00=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*0,srcY));\n"
" FLOAT4 S10=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*1,srcY));\n"
" FLOAT4 S20=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*2,srcY));\n"
" FLOAT4 S30=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*3,srcY));\n"
" FLOAT4 S01=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*4,srcY));\n"
" FLOAT4 S11=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*5,srcY));\n"
" FLOAT4 S21=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*6,srcY));\n"
" FLOAT4 S31=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*7,srcY));\n"
" FLOAT4 S02=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*8,srcY));\n"
" FLOAT4 S12=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*9,srcY));\n"
" FLOAT4 S22=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*10,srcY));\n"
" FLOAT4 S32=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*11,srcY));\n"
" FLOAT4 S03=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*12,srcY));\n"
" FLOAT4 S13=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*13,srcY));\n"
" FLOAT4 S23=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*14,srcY));\n"
" FLOAT4 S33=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*15,srcY));\n"
// First pass of A^T along b: m_a0 = S_a0+S_a1+S_a2, m_a1 = S_a1-S_a2+S_a3.
" FLOAT4 m00=+S00+S01+S02;\n"
" FLOAT4 m10=+S10+S11+S12;\n"
" FLOAT4 m20=+S20+S21+S22;\n"
" FLOAT4 m30=+S30+S31+S32;\n"
" FLOAT4 m01=+S01-S02+S03;\n"
" FLOAT4 m11=+S11-S12+S13;\n"
" FLOAT4 m21=+S21-S22+S23;\n"
" FLOAT4 m31=+S31-S32+S33;\n"
// Second pass along a, one scoped block per output pixel of the 2x2 tile.
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m00+m10+m20;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m10-m20+m30;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m01+m11+m21;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m11-m21+m31;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for the buffer-backed kernel `layernorm_buf`.
// Each row of `inside` contiguous elements (row = global id 1, base offset pos.y*inside)
// is normalised as (x - mean) / sqrt(mean((x - mean)^2) + epsilon). When GAMMA_BETA is
// defined, the result is then multiplied by gamma[i] and shifted by beta[i]. With
// RMSNORM the mean is fixed at 0, so the row is divided by sqrt(mean(x^2) + epsilon).
//
// LOCAL_SIZE > 1: the row is split across the work-group, and partial sums are reduced
// in local memory `sum[]`. PACK_LEAVE processes a trailing partial vec4 of
// `inside_remain` scalars separately.
// Otherwise: each work-item walks its row serially.
//
// FLOAT / CONVERT_FLOAT4 / LOCAL_SIZE and the feature macros are not defined in this
// text; they are presumably supplied as build options by the host (confirm).
//
// Fixes relative to the previous text:
//  * the PACK_LEAVE tail loops were missing their opening brace (compile error);
//  * the PACK_LEAVE variance tail subtracted the float4 `mean` from a scalar (now mean.x);
//  * the serial RMSNORM path used `in_sum` without declaring it;
//  * a barrier now separates every work-item's read of sum[0] (the mean) from the next
//    write into sum[], removing a write-after-read race in local memory.
// NOTE(review): this table appears to be generated from .cl sources; mirror the fix there.
const char* layernorm_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void layernorm_buf(__private int global_dim0,__private int global_dim1,\n"
" __global const FLOAT*input,\n"
" __global FLOAT*output,\n"
" __private const int inside,\n"
"#ifdef GAMMA_BETA\n"
" __global const FLOAT *gamma,\n"
" __global const FLOAT *beta,\n"
"#endif\n"
" __private float epsilon){\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
"#if LOCAL_SIZE>1\n"
" float local sum[LOCAL_SIZE];\n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int lid=get_local_id(0);\n"
" const int offset=pos.y*inside;\n"
" const int inside_v4=(inside+3) >> 2;\n"
" #ifdef PACK_LEAVE\n"
" const int loop=inside_v4-1;\n"
" const int inside_remain=inside-((inside_v4-1) << 2);\n"
" #else\n"
" const int loop=inside_v4;\n"
" #endif\n"
" \n"
" float4 in_sum=0;\n"
" int index=lid;\n"
" #ifdef RMSNORM\n"
" float4 mean=(float4)0;\n"
" #else\n"
" for(; index<loop; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" in_sum += in;\n"
" }\n"
" sum[lid]=in_sum.x+in_sum.y+in_sum.z+ in_sum.w;\n"
" \n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i){\n"
" float in=input[offset+index*4+i];\n"
" sum[lid]=sum[lid]+in;\n"
" }\n"
" }\n"
" #endif\n"
" \n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=sum[0]/(float4)inside;\n"
" // all work-items must read sum[0] before sum[] is overwritten below\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" #endif\n"
" in_sum=0;\n"
" index=lid;\n"
" for(; index<loop; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" sum[lid]=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i){\n"
" float in=input[offset+index*4+i];\n"
" in=(in-mean.x)*(in-mean.x);\n"
" sum[lid]=sum[lid]+in;\n"
" }\n"
" }\n"
" #endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float4 square_sum=sum[0]/(float4)inside;\n"
" float4 value=(float4)1.0f/(float4)sqrt(square_sum+(float4)epsilon);\n"
" index=lid;\n"
" for(; index<loop; index+=LOCAL_SIZE){\n"
" float4 in=convert_float4(vload4(index,input+offset));\n"
" #ifdef GAMMA_BETA\n"
" float4 out=(in-mean)*value*convert_float4(vload4(index,gamma))+convert_float4(vload4(index,beta));\n"
" #else\n"
" float4 out=(in-mean)*value;\n"
" #endif\n"
" vstore4(CONVERT_FLOAT4(out),index,output+offset);\n"
" }\n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i){\n"
" float in=input[offset+index*4+i];\n"
" #ifdef GAMMA_BETA\n"
" float out=(in-mean.x)*value.x*(float)gamma[index*4+i]+(float)beta[index*4+i];\n"
" #else\n"
" float out=(in-mean.x)*value.x;\n"
" #endif\n"
" output[offset+index*4+i]=out;\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
"#else\n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int offset=pos.y*inside;\n"
" float in_sum=0;\n"
" #ifdef RMSNORM\n"
" float mean=0;\n"
" #else\n"
" for(int index=0; index<inside; index++){\n"
" in_sum += (float)input[offset+index];\n"
" }\n"
" float mean=in_sum/inside;\n"
" #endif\n"
" in_sum=0;\n"
" for(int index=0; index<inside; index++){\n"
" float in=(float)input[offset+index];\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" float square_sum=in_sum/inside;\n"
" float value=1.0f/sqrt(square_sum+epsilon);\n"
" for(int i=0; i<inside; ++i){\n"
" float in=input[offset+i];\n"
" #ifdef GAMMA_BETA\n"
" float out=(in-mean)*value*(float)gamma[i]+(float)beta[i];\n"
" #else\n"
" float out=(in-mean)*value;\n"
" #endif\n"
" output[offset+i]=out;\n"
" }\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for the buffer-backed softmax kernels. Each computes
// exp(x - max) / sum(exp(x - max)) along `dim`:
//   softmax_in1_buf : elements contiguous (offset = z*dim + y; comment marks inside=1),
//                     processed as vec4 up to loop_end with a scalar tail.
//   softmax_buf     : elements strided by `inside` (offset = z*dim*inside + y).
//   softmax_v4_buf  : like softmax_buf, but each work-item handles 4 adjacent
//                     inside-elements via vload4/vstore4 (offset = z*dim*inside + y*4).
// When SOFTMAX_LOCAL_SIZE >= 4, the max and the exp-sum are reduced across the
// work-group in local memory `sum[]`. Otherwise each work-item loops over `dim` alone.
// FLOAT / COMPUTE_FLOAT* / CONVERT_* are not defined in this text; they are presumably
// supplied as build options by the host (confirm).
//
// Fix: `sum[]` is reused for the exp-sum reduction right after each work-item reads
// the max from sum[0]. A barrier now separates that read from the next write, so a
// fast work-item cannot overwrite sum[0] before slower ones have read it.
// NOTE(review): the local path requires every work-item of a group to reach each
// barrier, i.e. none may take the DEAL_NON_UNIFORM_DIM3 early return; confirm the
// host sizes the global range accordingly.
// NOTE(review): this table appears to be generated from .cl sources; mirror the fix there.
const char* softmax_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define EXP exp\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__kernel void softmax_in1_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,\n"
" __global FLOAT *output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside=1\n"
" const int z=get_global_id(2); // outside\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int offset=z*dim+y;\n"
" const int dim4=(dim+3)/4;\n"
" const int loop_end=max(0,dim4-1);\n"
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" COMPUTE_FLOAT local sum[SOFTMAX_LOCAL_SIZE];\n"
" // compute maxvalue\n"
" COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)-FLT_MAX;\n"
" for (int i=lid; i<loop_end; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,CONVERT_COMPUTE_FLOAT4(vload4(i,input+offset)));\n"
" }\n"
" sum[lid]=fmax(fmax(fmax(maxValue.x,maxValue.y),maxValue.z),maxValue.w);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue.x=sum[0];\n"
" // every work-item must read sum[0] before sum[] is reused for the exp-sum\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=loop_end << 2; i<dim; ++i){\n"
" maxValue.x=fmax(maxValue.x,(COMPUTE_FLOAT)(input[offset+i]));\n"
" }\n"
" // compute sumvalue\n"
" COMPUTE_FLOAT4 sumValue=(COMPUTE_FLOAT4)0;\n"
" for (int i=lid; i<loop_end; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(i,input+offset))-(COMPUTE_FLOAT4)maxValue.x);\n"
" }\n"
" sum[lid]=sumValue.x+sumValue.y+sumValue.z+sumValue.w;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue.x=sum[0];\n"
" for(int i=loop_end << 2; i<dim; ++i){\n"
" sumValue.x += exp((COMPUTE_FLOAT)(input[offset+i])-maxValue.x);\n"
" }\n"
" \n"
" // store result\n"
" for(int i=lid; i<loop_end; i+=SOFTMAX_LOCAL_SIZE){\n"
" vstore4(CONVERT_FLOAT4(exp(CONVERT_COMPUTE_FLOAT4(vload4(i,input+offset))-(COMPUTE_FLOAT4)maxValue.x)/(COMPUTE_FLOAT4)sumValue.x),0,output+offset+i*4);\n"
" }\n"
" for(int i=loop_end << 2; i<dim; ++i){\n"
" output[offset+i]=(FLOAT)exp((COMPUTE_FLOAT)(input[offset+i])-maxValue.x)/sumValue.x;\n"
" }\n"
"#else\n"
" // compute maxvalue\n"
" COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)-FLT_MAX;\n"
" for (int i=0; i<loop_end; i++) {\n"
" maxValue=fmax(maxValue,CONVERT_COMPUTE_FLOAT4(vload4(i,input+offset)));\n"
" }\n"
" maxValue.x=fmax(fmax(fmax(maxValue.x,maxValue.y),maxValue.z),maxValue.w);\n"
" for(int i=loop_end << 2; i<dim; ++i){\n"
" maxValue.x=fmax(maxValue.x,(COMPUTE_FLOAT)(input[offset+i]));\n"
" }\n"
" \n"
" // compute sumvalue\n"
" COMPUTE_FLOAT4 sumValue=(COMPUTE_FLOAT4)0;\n"
" for (int i=0; i<loop_end; i++) {\n"
" sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(i,input+offset))-(COMPUTE_FLOAT4)maxValue.x);\n"
" }\n"
" sumValue.x=sumValue.x+sumValue.y+sumValue.z+sumValue.w;\n"
" for(int i=loop_end << 2; i<dim; ++i){\n"
" sumValue.x += exp((COMPUTE_FLOAT)(input[offset+i])-maxValue.x);\n"
" }\n"
" \n"
" // store result\n"
" for(int i=0; i<loop_end; i++){\n"
" vstore4(CONVERT_FLOAT4(exp(CONVERT_COMPUTE_FLOAT4(vload4(i,input+offset))-(COMPUTE_FLOAT4)maxValue.x)/(COMPUTE_FLOAT4)sumValue.x),0,output+offset+i*4);\n"
" }\n"
" for(int i=loop_end << 2; i<dim; ++i){\n"
" output[offset+i]=(FLOAT)exp((COMPUTE_FLOAT)(input[offset+i])-maxValue.x)/sumValue.x;\n"
" }\n"
"#endif\n"
"}\n"
"__kernel void softmax_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,\n"
" __global FLOAT *output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside\n"
" const int z=get_global_id(2); // outside\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int offset=z*dim*inside+y;\n"
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" COMPUTE_FLOAT local sum[SOFTMAX_LOCAL_SIZE];\n"
" COMPUTE_FLOAT maxValue=(COMPUTE_FLOAT)-FLT_MAX;\n"
" for (int i=lid; i<dim; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,(COMPUTE_FLOAT)(input[offset+i*inside]));\n"
" }\n"
" sum[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=sum[0];\n"
" // every work-item must read sum[0] before sum[] is reused for the exp-sum\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" COMPUTE_FLOAT sumValue=(COMPUTE_FLOAT)0;\n"
" for (int i=lid; i<dim; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp((COMPUTE_FLOAT)(input[offset+i*inside])-maxValue);\n"
" }\n"
" sum[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum[0];\n"
" for(int i=lid; i<dim; i+=SOFTMAX_LOCAL_SIZE){\n"
" output[offset+i*inside]=(FLOAT)exp((COMPUTE_FLOAT)(input[offset+i*inside])-maxValue)/sumValue;\n"
" }\n"
"#else\n"
" COMPUTE_FLOAT maxValue=(COMPUTE_FLOAT)-FLT_MAX;\n"
" for (int i=0; i<dim; i++) {\n"
" maxValue=fmax(maxValue,(COMPUTE_FLOAT)(input[offset+i*inside]));\n"
" }\n"
" COMPUTE_FLOAT sumValue=(COMPUTE_FLOAT)0;\n"
" for (int i=0; i<dim; i++) {\n"
" sumValue += exp((COMPUTE_FLOAT)(input[offset+i*inside])-maxValue);\n"
" }\n"
" for(int i=0; i<dim; i++){\n"
" output[offset+i*inside]=(FLOAT)exp((COMPUTE_FLOAT)(input[offset+i*inside])-maxValue)/sumValue;\n"
" }\n"
"#endif\n"
"}\n"
"__kernel void softmax_v4_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,\n"
" __global FLOAT *output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside\n"
" const int z=get_global_id(2); // outside\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int offset=z*dim*inside+(y << 2);\n"
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n"
" COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)-FLT_MAX;\n"
" for (int i=lid; i<dim; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset+i*inside)));\n"
" }\n"
" sum[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=fmax(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=sum[0];\n"
" // every work-item must read sum[0] before sum[] is reused for the exp-sum\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" COMPUTE_FLOAT4 sumValue=(COMPUTE_FLOAT4)0;\n"
" for (int i=lid; i<dim; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset+i*inside))-maxValue);\n"
" }\n"
" sum[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum[0];\n"
" for(int i=lid; i<dim; i+=SOFTMAX_LOCAL_SIZE){\n"
" vstore4(CONVERT_FLOAT4(exp(CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset+i*inside))-maxValue)/sumValue),0,output+offset+i*inside);\n"
" }\n"
"#else\n"
" COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)-FLT_MAX;\n"
" for (int i=0; i<dim; i++) {\n"
" maxValue=fmax(maxValue,CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset+i*inside)));\n"
" }\n"
" COMPUTE_FLOAT4 sumValue=(COMPUTE_FLOAT4)0;\n"
" for (int i=0; i<dim; i++) {\n"
" sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset+i*inside))-maxValue);\n"
" }\n"
" for(int i=0; i<dim; i++){\n"
" vstore4(CONVERT_FLOAT4(exp(CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset+i*inside))-maxValue)/sumValue),0,output+offset+i*inside);\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source for kernel `batch_gather_buf` (buffer based, 3-D global range).
// For each in-range work-item pos:
//   x = pos.x % x_size, y = pos.x / x_size;
//   index = (pos.z, pos.z), where index.x is replaced by offset_dst_ptr[pos.z] under
//   OFFSET_DST and index.y by offset_src_ptr[pos.z] under OFFSET_SRC;
//   offset = index * steps (component-wise; .x is the destination, .y the source);
//   src/dst element = offset + stride.w + x*stride.x + y*stride.y + pos.y*stride.z.
// Nothing is written when offset.x < 0. Otherwise the source element is copied
// (cast to OUTPUT_TYPE) if 0 <= offset.y < inputSize, and zero is written if not.
// NOTE(review): the bound test applies only to offset.y, not to the full src_offset
// including the stride terms; presumably the host guarantees the stride part stays in
// range — confirm.
// NOTE(review): offset_dst_shape, offset_src_shape and iters are declared but not
// referenced by the kernel body.
// OUTPUT_TYPE / INPUT_TYPE are not defined in this text; presumably supplied as build
// options by the host — confirm.
const char* gather_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void batch_gather_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input,\n"
" #ifdef OFFSET_DST\n"
" __global int* offset_dst_ptr,\n"
" __private const int4 offset_dst_shape,// w,h,c,n\n"
" #endif\n"
" #ifdef OFFSET_SRC\n"
" __global int* offset_src_ptr,\n"
" __private const int4 offset_src_shape,// w,h,c,n\n"
" #endif\n"
" __private const int x_size,\n"
" __private const int4 stride_src,\n"
" __private const int4 stride_dst,\n"
" __private const int2 steps,\n"
" __private const int2 iters,\n"
" __private const int inputSize) {\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" \n"
" int x=pos.x % x_size;\n"
" int y=pos.x/x_size;\n"
" int2 index=(int2)(pos.z,pos.z);\n"
"#ifdef OFFSET_DST\n"
" index.x=offset_dst_ptr[pos.z];\n"
"#endif\n"
" \n"
"#ifdef OFFSET_SRC\n"
" index.y=offset_src_ptr[pos.z];\n"
"#endif\n"
" int2 offset=index*steps;\n"
" int src_offset=offset.y+stride_src.w+x*stride_src.x+y*stride_src.y+pos.y*stride_src.z;\n"
" int dst_offset=offset.x+stride_dst.w+x*stride_dst.x+y*stride_dst.y+pos.y*stride_dst.z;\n"
// A negative destination index skips the write; an out-of-range source index writes 0.
" if(offset.x >= 0){\n"
" if(offset.y >= 0 && offset.y<inputSize){\n"
" output[dst_offset]=(OUTPUT_TYPE)input[src_offset];\n"
" }else{\n"
" output[dst_offset]=(OUTPUT_TYPE)(0);\n"
" }\n"
" }\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
const char* conv_2d_c16_subgroup_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#define GROUP_READ(ptr,offset) as_half(intel_sub_group_block_read_us((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_READ2(ptr,offset) as_half2(intel_sub_group_block_read_us2((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_READ4(ptr,offset) as_half4(intel_sub_group_block_read_us4((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_READ8(ptr,offset) as_half8(intel_sub_group_block_read_us8((const __global ushort*)(ptr)+(offset)))\n"
"#define GROUP_WRITE(ptr,offset,val) intel_sub_group_block_write_us((const __global ushort*)(ptr)+(offset),as_ushort(val))\n"
"#define GROUP_WRITE2(ptr,offset,val) intel_sub_group_block_write_us2((const __global ushort*)(ptr)+(offset),as_ushort2(val))\n"
"#define GROUP_WRITE4(ptr,offset,val) intel_sub_group_block_write_us4((const __global ushort*)(ptr)+(offset),as_ushort4(val))\n"
"#define GROUP_WRITE8(ptr,offset,val) intel_sub_group_block_write_us8((const __global ushort*)(ptr)+(offset),as_ushort8(val))\n"
"#define GROUP_SHUFFLE(data,id) as_half(intel_sub_group_shuffle(as_ushort(data),id))\n"
"#define GROUP_SHUFFLE2(data,id) as_half2(intel_sub_group_shuffle(as_ushort2(data),id))\n"
"#define GROUP_SHUFFLE4(data,id) as_half4(intel_sub_group_shuffle(as_ushort4(data),id))\n"
"#define GROUP_SHUFFLE8(data,id) as_half8(intel_sub_group_shuffle(as_ushort8(data),id))\n"
"#else\n"
"#define GROUP_READ(ptr,offset) as_float(intel_sub_group_block_read((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_READ2(ptr,offset) as_float2(intel_sub_group_block_read2((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_READ4(ptr,offset) as_float4(intel_sub_group_block_read4((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_READ8(ptr,offset) as_float8(intel_sub_group_block_read8((const __global uint*)(ptr)+(offset)))\n"
"#define GROUP_WRITE(ptr,offset,val) intel_sub_group_block_write((const __global uint*)(ptr)+(offset),as_uint(val))\n"
"#define GROUP_WRITE2(ptr,offset,val) intel_sub_group_block_write2((const __global uint*)(ptr)+(offset),as_uint2(val))\n"
"#define GROUP_WRITE4(ptr,offset,val) intel_sub_group_block_write4((const __global uint*)(ptr)+(offset),as_uint4(val))\n"
"#define GROUP_WRITE8(ptr,offset,val) intel_sub_group_block_write8((const __global uint*)(ptr)+(offset),as_uint8(val))\n"
"#define GROUP_SHUFFLE(data,id) intel_sub_group_shuffle(data,id)\n"
"#define GROUP_SHUFFLE2(data,id) intel_sub_group_shuffle(data,id)\n"
"#define GROUP_SHUFFLE4(data,id) intel_sub_group_shuffle(data,id)\n"
"#define GROUP_SHUFFLE8(data,id) intel_sub_group_shuffle(data,id)\n"
"#endif\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c16_c4_b2(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
") {\n"
" const int sglid=get_sub_group_local_id();\n"
" const int b=(uint)get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 1;\n"
" const int y=(xy/x_blocks);\n"
" const int lid1=(int)get_local_id(1);\n"
" const int feature_per_wg=(int)get_local_size(1)/SLM_DIV_FACTOR;\n"
" const int feature_sub_block=lid1/feature_per_wg;\n"
" const int feature_block=(int)get_group_id(1);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(input_height);\n"
" const uint input_b_pitch=input_fs_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*output_width;\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*batch;\n"
" const uint output_offset=b*output_fs_pitch +\n"
" (feature_block << 2)*output_b_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=16*16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=feature_block*filter_os_pitch;\n"
"#if SLM_DIV_FACTOR == 1\n"
" COMPUTE_FLOAT2 dst=(COMPUTE_FLOAT2)((GROUP_READ(biases,feature_block*16)));\n"
"#else\n"
" COMPUTE_FLOAT2 dst;\n"
" if (feature_sub_block == 0) {\n"
" dst=(COMPUTE_FLOAT2)((GROUP_READ(biases,feature_block*16)));\n"
" } else {\n"
" dst=(COMPUTE_FLOAT2)0;\n"
" }\n"
"#endif \n"
"#if SLM_DIV_FACTOR>1\n"
" __local COMPUTE_FLOAT2 sum[WORK_GROUP_SIZE];\n"
"#endif\n"
"#if SLM_DIV_FACTOR>1\n"
" for (int icb=feature_sub_block*IC_BLOCKS/SLM_DIV_FACTOR; icb<(feature_sub_block+1)*IC_BLOCKS/SLM_DIV_FACTOR; icb++) {\n"
"#else\n"
" for (int icb=0; icb<IC_BLOCKS; icb++) {\n"
"#endif \n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++) {\n"
" if (input_y+kh*DILATION_HEIGHT<0 || input_y+kh*DILATION_HEIGHT >= input_height)\n"
" continue;\n"
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" {\n"
" int xb=0;\n"
" for (; xb+8 <= INPUT_LINE_SIZE; xb += 8) {\n"
" COMPUTE_FLOAT8 tmp=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" line_cache[xb+4]=tmp[4];\n"
" line_cache[xb+5]=tmp[5];\n"
" line_cache[xb+6]=tmp[6];\n"
" line_cache[xb+7]=tmp[7];\n"
" }\n"
" for (; xb+4 <= INPUT_LINE_SIZE; xb += 4) {\n"
" COMPUTE_FLOAT4 tmp=CONVERT_COMPUTE_FLOAT4(GROUP_READ4(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" }\n"
" for (; xb<INPUT_LINE_SIZE; xb++) {\n"
" line_cache[xb]=GROUP_READ(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch);\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++) {\n"
" FLOAT2 src;\n"
" __attribute__((opencl_unroll_hint(2)))\n"
" for (int i=0; i<2; i++) {\n"
"#if FILTER_WIDTH == 1 && DILATION_WIDTH == 1 && STRIDE_WIDTH == 1\n"
" src[i]=line_cache[i];\n"
"#else\n"
" src[i]=line_cache[kw*DILATION_WIDTH+STRIDE_WIDTH*i];\n"
"#endif\n"
" }\n"
" COMPUTE_FLOAT8 weight0=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch));\n"
" COMPUTE_FLOAT8 weight1=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch +\n"
" 8*filter_isv_pitch));\n"
" const COMPUTE_FLOAT2 src0=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,0));\n"
" const COMPUTE_FLOAT2 src1=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,1));\n"
" const COMPUTE_FLOAT2 src2=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,2));\n"
" const COMPUTE_FLOAT2 src3=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,3));\n"
" const COMPUTE_FLOAT2 src4=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,4));\n"
" const COMPUTE_FLOAT2 src5=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,5));\n"
" const COMPUTE_FLOAT2 src6=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,6));\n"
" const COMPUTE_FLOAT2 src7=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,7));\n"
" const COMPUTE_FLOAT2 src8=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,8));\n"
" const COMPUTE_FLOAT2 src9=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,9));\n"
" const COMPUTE_FLOAT2 src10=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,10));\n"
" const COMPUTE_FLOAT2 src11=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,11));\n"
" const COMPUTE_FLOAT2 src12=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,12));\n"
" const COMPUTE_FLOAT2 src13=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,13));\n"
" const COMPUTE_FLOAT2 src14=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,14));\n"
" const COMPUTE_FLOAT2 src15=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,15));\n"
" dst=mad(weight0.s0,src0,dst);\n"
" dst=mad(weight0.s1,src1,dst);\n"
" dst=mad(weight0.s2,src2,dst);\n"
" dst=mad(weight0.s3,src3,dst);\n"
" dst=mad(weight0.s4,src4,dst);\n"
" dst=mad(weight0.s5,src5,dst);\n"
" dst=mad(weight0.s6,src6,dst);\n"
" dst=mad(weight0.s7,src7,dst);\n"
" dst=mad(weight1.s0,src8,dst);\n"
" dst=mad(weight1.s1,src9,dst);\n"
" dst=mad(weight1.s2,src10,dst);\n"
" dst=mad(weight1.s3,src11,dst);\n"
" dst=mad(weight1.s4,src12,dst);\n"
" dst=mad(weight1.s5,src13,dst);\n"
" dst=mad(weight1.s6,src14,dst);\n"
" dst=mad(weight1.s7,src15,dst);\n"
" }\n"
" }\n"
" }\n"
" \n"
"#if SLM_DIV_FACTOR>1\n"
" sum[lid1]=dst;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" if (feature_sub_block == 0) {\n"
" __attribute__((opencl_unroll_hint)) for(int i=1; i<SLM_DIV_FACTOR; i++)\n"
" dst += sum[lid1 % feature_per_wg+i*feature_per_wg];\n"
"#endif\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n"
"#endif\n"
" const uint lid_x=sglid % 4;\n"
" const uint lid_y=sglid/4;\n"
" if ((feature_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<2 && (x+i)<output_width; i++) {\n"
" if ((feature_block*16+lid_y*4+lid_x<output_channel))\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" for (int i=0; i<2 && (x+i)<output_width; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" }\n"
"#endif\n"
"}\n"
// ---------------------------------------------------------------------------
// OpenCL kernel conv_2d_buf_subgroup_c16_c4_b4 (source text held in this
// string). Direct 2D convolution using sub-groups of exactly 16 lanes
// (intel_reqd_sub_group_size(16)). The input buffer stores 16 channels per
// pixel and the output buffer stores 4 channels per pixel. Each lane
// accumulates 4 horizontally adjacent output pixels (x .. x+3) for output
// channel feature_block*16 + sglid.
// NDRange mapping, as derived from the index math below:
//   dim 0 : flattened (x block, y): x = (xy % x_blocks) * 4, y = xy / x_blocks
//   dim 1 : group id = 16-channel output block; the local range is divided
//           into SLM_DIV_FACTOR sub-blocks that split the input-channel loop
//   dim 2 : batch index
// STRIDE_*, DILATION_*, FILTER_*, INPUT_CHANNEL, IC_BLOCKS, INPUT_LINE_SIZE,
// SLM_DIV_FACTOR, WORK_GROUP_SIZE, RELU and RELU6 are not defined in this
// kernel. They presumably come from the host build options; confirm this
// against the launching code.
// output_pad_left / output_pad_right are declared but never referenced here.
// ---------------------------------------------------------------------------
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c16_c4_b4(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
") {\n"
// Work-item coordinates: lane id, batch, first output pixel (x, y), and the
// SLM sub-block / 16-channel output block this lane works on.
" const int sglid=get_sub_group_local_id();\n"
" const int b=(uint)get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 2;\n"
" const int y=(xy/x_blocks);\n"
" const int lid1=(int)get_local_id(1);\n"
" const int feature_per_wg=(int)get_local_size(1)/SLM_DIV_FACTOR;\n"
" const int feature_sub_block=lid1/feature_per_wg;\n"
" const int feature_block=(int)get_group_id(1);\n"
// Top-left input coordinate of the receptive field. It can be negative when
// pad_width / pad_height > 0.
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
// Input layout [N][ceil(INPUT_CHANNEL/16)][H][pad_left+W+pad_right][16].
// A negative input_y makes the unsigned offset wrap. Only rows that pass the
// kh bounds check below are actually read.
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(input_height);\n"
" const uint input_b_pitch=input_fs_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
// Output layout [C/4][N][H][W][4]. Despite the names, output_fs_pitch is the
// per-batch stride and output_b_pitch is the stride between 4-channel
// slices; the slice is feature_block*4 + lid_y at store time.
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*output_width;\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*batch;\n"
" const uint output_offset=b*output_fs_pitch +\n"
" (feature_block << 2)*output_b_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
// Filter pitches: one 16x16 tile per (output block, input block, kh, kw).
// Consecutive input channels inside a tile are filter_isv_pitch=16 apart.
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=16*16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=feature_block*filter_os_pitch;\n"
// The accumulators start from the bias broadcast to all 4 pixels. With an
// SLM split, only sub-block 0 starts from the bias, so the local-memory
// reduction adds it exactly once.
"#if SLM_DIV_FACTOR == 1\n"
" COMPUTE_FLOAT4 dst=(COMPUTE_FLOAT4)((GROUP_READ(biases,feature_block*16)));\n"
"#else\n"
" COMPUTE_FLOAT4 dst;\n"
" if (feature_sub_block == 0) {\n"
" dst=(COMPUTE_FLOAT4)((GROUP_READ(biases,feature_block*16)));\n"
" } else {\n"
" dst=(COMPUTE_FLOAT4)0;\n"
" }\n"
"#endif \n"
// Local scratch for the cross-sub-block partial-sum reduction.
"#if SLM_DIV_FACTOR>1\n"
" __local COMPUTE_FLOAT4 sum[WORK_GROUP_SIZE];\n"
"#endif\n"
// Loop over 16-channel input blocks. With an SLM split, each sub-block
// handles its own contiguous share of IC_BLOCKS.
"#if SLM_DIV_FACTOR>1\n"
" for (int icb=feature_sub_block*IC_BLOCKS/SLM_DIV_FACTOR; icb<(feature_sub_block+1)*IC_BLOCKS/SLM_DIV_FACTOR; icb++) {\n"
"#else\n"
" for (int icb=0; icb<IC_BLOCKS; icb++) {\n"
"#endif \n"
// Filter rows that land outside [0, input_height) are skipped, which
// handles vertical padding. Horizontal padding is not checked here; it relies
// on the input_pad_left/right columns of the input buffer, which are
// presumably zero-filled by the producer (TODO confirm).
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++) {\n"
" if (input_y+kh*DILATION_HEIGHT<0 || input_y+kh*DILATION_HEIGHT >= input_height)\n"
" continue;\n"
// Cache INPUT_LINE_SIZE consecutive pixels of this input row. The reads are
// 8 wide first, then 4 wide, then scalar for the remainder.
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" {\n"
" int xb=0;\n"
" for (; xb+8 <= INPUT_LINE_SIZE; xb += 8) {\n"
" COMPUTE_FLOAT8 tmp=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" line_cache[xb+4]=tmp[4];\n"
" line_cache[xb+5]=tmp[5];\n"
" line_cache[xb+6]=tmp[6];\n"
" line_cache[xb+7]=tmp[7];\n"
" }\n"
" for (; xb+4 <= INPUT_LINE_SIZE; xb += 4) {\n"
" COMPUTE_FLOAT4 tmp=CONVERT_COMPUTE_FLOAT4(GROUP_READ4(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" }\n"
" for (; xb<INPUT_LINE_SIZE; xb++) {\n"
" line_cache[xb]=(COMPUTE_FLOAT)GROUP_READ(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch);\n"
" }\n"
" }\n"
// For each kw, gather the cached pixel that each of the 4 output positions
// needs: column kw*DILATION_WIDTH + STRIDE_WIDTH*i of the line.
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++) {\n"
" FLOAT4 src;\n"
" __attribute__((opencl_unroll_hint(4)))\n"
" for (int i=0; i<4; i++) {\n"
"#if FILTER_WIDTH == 1 && DILATION_WIDTH == 1 && STRIDE_WIDTH == 1\n"
" src[i]=line_cache[i];\n"
"#else\n"
" src[i]=line_cache[kw*DILATION_WIDTH+STRIDE_WIDTH*i];\n"
"#endif\n"
" }\n"
// weight0 covers input channels 0-7 of the 16x16 tile. weight1, offset by
// 8*filter_isv_pitch, covers input channels 8-15.
" COMPUTE_FLOAT8 weight0=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch));\n"
" COMPUTE_FLOAT8 weight1=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch +\n"
" 8*filter_isv_pitch));\n"
// srcN holds the 4 gathered pixels taken from lane N. This assumes lane N
// holds input channel N of the block; GROUP_SHUFFLE4 is a macro defined
// elsewhere, presumably a sub-group shuffle.
" const COMPUTE_FLOAT4 src0=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,0));\n"
" const COMPUTE_FLOAT4 src1=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,1));\n"
" const COMPUTE_FLOAT4 src2=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,2));\n"
" const COMPUTE_FLOAT4 src3=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,3));\n"
" const COMPUTE_FLOAT4 src4=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,4));\n"
" const COMPUTE_FLOAT4 src5=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,5));\n"
" const COMPUTE_FLOAT4 src6=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,6));\n"
" const COMPUTE_FLOAT4 src7=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,7));\n"
" const COMPUTE_FLOAT4 src8=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,8));\n"
" const COMPUTE_FLOAT4 src9=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,9));\n"
" const COMPUTE_FLOAT4 src10=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,10));\n"
" const COMPUTE_FLOAT4 src11=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,11));\n"
" const COMPUTE_FLOAT4 src12=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,12));\n"
" const COMPUTE_FLOAT4 src13=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,13));\n"
" const COMPUTE_FLOAT4 src14=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,14));\n"
" const COMPUTE_FLOAT4 src15=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,15));\n"
// Accumulate weight(input channel N) * srcN over all 16 input channels.
" dst=mad(weight0.s0,src0,dst);\n"
" dst=mad(weight0.s1,src1,dst);\n"
" dst=mad(weight0.s2,src2,dst);\n"
" dst=mad(weight0.s3,src3,dst);\n"
" dst=mad(weight0.s4,src4,dst);\n"
" dst=mad(weight0.s5,src5,dst);\n"
" dst=mad(weight0.s6,src6,dst);\n"
" dst=mad(weight0.s7,src7,dst);\n"
" dst=mad(weight1.s0,src8,dst);\n"
" dst=mad(weight1.s1,src9,dst);\n"
" dst=mad(weight1.s2,src10,dst);\n"
" dst=mad(weight1.s3,src11,dst);\n"
" dst=mad(weight1.s4,src12,dst);\n"
" dst=mad(weight1.s5,src13,dst);\n"
" dst=mad(weight1.s6,src14,dst);\n"
" dst=mad(weight1.s7,src15,dst);\n"
" }\n"
" }\n"
" }\n"
" \n"
// SLM reduction: every sub-block publishes its partial sum. After the
// barrier, sub-block 0 adds the sums of sub-blocks 1..SLM_DIV_FACTOR-1.
// The `if` opened here is closed by the trailing #if SLM_DIV_FACTOR>1, so
// only sub-block 0 applies the activation and stores.
"#if SLM_DIV_FACTOR>1\n"
" sum[lid1]=dst;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" if (feature_sub_block == 0) {\n"
" __attribute__((opencl_unroll_hint)) for(int i=1; i<SLM_DIV_FACTOR; i++)\n"
" dst += sum[lid1 % feature_per_wg+i*feature_per_wg];\n"
"#endif\n"
// Optional fused activation, selected at build time.
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
// Store: lid_x is the channel within a 4-channel group and lid_y selects one
// of the 4 groups in this 16-channel block. Pixels past output_width are
// dropped. In the last channel block, lanes whose channel >= output_channel
// write nothing, so the padded tail of the last 4-channel group is not
// written by this kernel. NOTE(review): confirm consumers do not read it.
" const uint lid_x=sglid % 4;\n"
" const uint lid_y=sglid/4;\n"
" if ((feature_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<4 && (x+i)<output_width; i++) {\n"
" if ((feature_block*16+lid_y*4+lid_x<output_channel))\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" for (int i=0; i<4 && (x+i)<output_width; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" }\n"
"#endif\n"
"}\n"
// ---------------------------------------------------------------------------
// OpenCL kernel conv_2d_buf_subgroup_c16_c4_b8 (source text held in this
// string). Direct 2D convolution using sub-groups of exactly 16 lanes
// (intel_reqd_sub_group_size(16)). The input buffer stores 16 channels per
// pixel and the output buffer stores 4 channels per pixel. Each lane
// accumulates 8 horizontally adjacent output pixels (x .. x+7) for output
// channel feature_block*16 + sglid.
// NDRange mapping, as derived from the index math below:
//   dim 0 : flattened (x block, y): x = (xy % x_blocks) * 8, y = xy / x_blocks
//   dim 1 : group id = 16-channel output block; the local range is divided
//           into SLM_DIV_FACTOR sub-blocks that split the input-channel loop
//   dim 2 : batch index
// STRIDE_*, DILATION_*, FILTER_*, INPUT_CHANNEL, IC_BLOCKS, INPUT_LINE_SIZE,
// SLM_DIV_FACTOR, WORK_GROUP_SIZE, RELU and RELU6 are not defined in this
// kernel. They presumably come from the host build options; confirm this
// against the launching code.
// output_pad_left / output_pad_right are declared but never referenced here.
// ---------------------------------------------------------------------------
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c16_c4_b8(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
") {\n"
// Work-item coordinates: lane id, batch, first output pixel (x, y), and the
// SLM sub-block / 16-channel output block this lane works on.
" const int sglid=get_sub_group_local_id();\n"
" const int b=(uint)get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 3;\n"
" const int y=(xy/x_blocks);\n"
" const int lid1=(int)get_local_id(1);\n"
" const int feature_per_wg=(int)get_local_size(1)/SLM_DIV_FACTOR;\n"
" const int feature_sub_block=lid1/feature_per_wg;\n"
" const int feature_block=(int)get_group_id(1);\n"
// Top-left input coordinate of the receptive field. It can be negative when
// pad_width / pad_height > 0.
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
// Input layout [N][ceil(INPUT_CHANNEL/16)][H][pad_left+W+pad_right][16].
// A negative input_y makes the unsigned offset wrap. Only rows that pass the
// kh bounds check below are actually read.
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(input_height);\n"
" const uint input_b_pitch=input_fs_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
// Output layout [C/4][N][H][W][4]. Despite the names, output_fs_pitch is the
// per-batch stride and output_b_pitch is the stride between 4-channel
// slices; the slice is feature_block*4 + lid_y at store time.
" const uint output_x_pitch=4;\n"
" const uint output_y_pitch=output_x_pitch*output_width;\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*batch;\n"
" const uint output_offset=b*output_fs_pitch +\n"
" (feature_block << 2)*output_b_pitch +\n"
" y*output_y_pitch +\n"
" x*output_x_pitch;\n"
// Filter pitches: one 16x16 tile per (output block, input block, kh, kw).
// Consecutive input channels inside a tile are filter_isv_pitch=16 apart.
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=16*16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=feature_block*filter_os_pitch;\n"
// The accumulators start from the bias broadcast to all 8 pixels. With an
// SLM split, only sub-block 0 starts from the bias, so the local-memory
// reduction adds it exactly once.
"#if SLM_DIV_FACTOR == 1\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(GROUP_READ(biases,feature_block*16));\n"
"#else\n"
" COMPUTE_FLOAT8 dst;\n"
" if (feature_sub_block == 0) {\n"
" dst=(COMPUTE_FLOAT8)(GROUP_READ(biases,feature_block*16));\n"
" } else {\n"
" dst=(COMPUTE_FLOAT8)0;\n"
" }\n"
"#endif \n"
// Local scratch for the cross-sub-block partial-sum reduction.
"#if SLM_DIV_FACTOR>1\n"
" __local COMPUTE_FLOAT8 sum[WORK_GROUP_SIZE];\n"
"#endif\n"
// Loop over 16-channel input blocks. With an SLM split, each sub-block
// handles its own contiguous share of IC_BLOCKS.
"#if SLM_DIV_FACTOR>1\n"
" for (int icb=feature_sub_block*IC_BLOCKS/SLM_DIV_FACTOR; icb<(feature_sub_block+1)*IC_BLOCKS/SLM_DIV_FACTOR; icb++) {\n"
"#else\n"
" for (int icb=0; icb<IC_BLOCKS; icb++) {\n"
"#endif \n"
// Filter rows that land outside [0, input_height) are skipped, which
// handles vertical padding. Horizontal padding is not checked here; it relies
// on the input_pad_left/right columns of the input buffer, which are
// presumably zero-filled by the producer (TODO confirm).
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++) {\n"
" if (input_y+kh*DILATION_HEIGHT<0 || input_y+kh*DILATION_HEIGHT >= input_height)\n"
" continue;\n"
// Cache INPUT_LINE_SIZE consecutive pixels of this input row. The reads are
// 8 wide first, then 4 wide, then scalar for the remainder.
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" {\n"
" int xb=0;\n"
" for (; xb+8 <= INPUT_LINE_SIZE; xb += 8) {\n"
" COMPUTE_FLOAT8 tmp=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" line_cache[xb+4]=tmp[4];\n"
" line_cache[xb+5]=tmp[5];\n"
" line_cache[xb+6]=tmp[6];\n"
" line_cache[xb+7]=tmp[7];\n"
" }\n"
" for (; xb+4 <= INPUT_LINE_SIZE; xb += 4) {\n"
" COMPUTE_FLOAT4 tmp=CONVERT_COMPUTE_FLOAT4(GROUP_READ4(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" }\n"
" for (; xb<INPUT_LINE_SIZE; xb++) {\n"
" line_cache[xb]=(COMPUTE_FLOAT)GROUP_READ(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch);\n"
" }\n"
" }\n"
// For each kw, gather the cached pixel that each of the 8 output positions
// needs: column kw*DILATION_WIDTH + STRIDE_WIDTH*i of the line.
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++) {\n"
" FLOAT8 src;\n"
" __attribute__((opencl_unroll_hint(8)))\n"
" for (int i=0; i<8; i++) {\n"
"#if FILTER_WIDTH == 1 && DILATION_WIDTH == 1 && STRIDE_WIDTH == 1\n"
" src[i]=line_cache[i];\n"
"#else\n"
" src[i]=line_cache[kw*DILATION_WIDTH+STRIDE_WIDTH*i];\n"
"#endif\n"
" }\n"
// weight0 covers input channels 0-7 of the 16x16 tile. weight1, offset by
// 8*filter_isv_pitch, covers input channels 8-15.
" COMPUTE_FLOAT8 weight0=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch));\n"
" COMPUTE_FLOAT8 weight1=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch +\n"
" 8*filter_isv_pitch));\n"
// srcN holds the 8 gathered pixels taken from lane N. This assumes lane N
// holds input channel N of the block; GROUP_SHUFFLE8 is a macro defined
// elsewhere, presumably a sub-group shuffle.
" const COMPUTE_FLOAT8 src0=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,0));\n"
" const COMPUTE_FLOAT8 src1=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,1));\n"
" const COMPUTE_FLOAT8 src2=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,2));\n"
" const COMPUTE_FLOAT8 src3=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,3));\n"
" const COMPUTE_FLOAT8 src4=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,4));\n"
" const COMPUTE_FLOAT8 src5=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,5));\n"
" const COMPUTE_FLOAT8 src6=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,6));\n"
" const COMPUTE_FLOAT8 src7=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,7));\n"
" const COMPUTE_FLOAT8 src8=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,8));\n"
" const COMPUTE_FLOAT8 src9=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,9));\n"
" const COMPUTE_FLOAT8 src10=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,10));\n"
" const COMPUTE_FLOAT8 src11=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,11));\n"
" const COMPUTE_FLOAT8 src12=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,12));\n"
" const COMPUTE_FLOAT8 src13=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,13));\n"
" const COMPUTE_FLOAT8 src14=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,14));\n"
" const COMPUTE_FLOAT8 src15=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,15));\n"
// Accumulate weight(input channel N) * srcN over all 16 input channels.
" dst=mad(weight0.s0,src0,dst);\n"
" dst=mad(weight0.s1,src1,dst);\n"
" dst=mad(weight0.s2,src2,dst);\n"
" dst=mad(weight0.s3,src3,dst);\n"
" dst=mad(weight0.s4,src4,dst);\n"
" dst=mad(weight0.s5,src5,dst);\n"
" dst=mad(weight0.s6,src6,dst);\n"
" dst=mad(weight0.s7,src7,dst);\n"
" dst=mad(weight1.s0,src8,dst);\n"
" dst=mad(weight1.s1,src9,dst);\n"
" dst=mad(weight1.s2,src10,dst);\n"
" dst=mad(weight1.s3,src11,dst);\n"
" dst=mad(weight1.s4,src12,dst);\n"
" dst=mad(weight1.s5,src13,dst);\n"
" dst=mad(weight1.s6,src14,dst);\n"
" dst=mad(weight1.s7,src15,dst);\n"
" }\n"
" }\n"
" }\n"
" \n"
// SLM reduction: every sub-block publishes its partial sum. After the
// barrier, sub-block 0 adds the sums of sub-blocks 1..SLM_DIV_FACTOR-1.
// The `if` opened here is closed by the trailing #if SLM_DIV_FACTOR>1, so
// only sub-block 0 applies the activation and stores.
"#if SLM_DIV_FACTOR>1\n"
" sum[lid1]=dst;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" if (feature_sub_block == 0) {\n"
" __attribute__((opencl_unroll_hint)) for(int i=1; i<SLM_DIV_FACTOR; i++)\n"
" dst += sum[lid1 % feature_per_wg+i*feature_per_wg];\n"
"#endif\n"
// Optional fused activation, selected at build time.
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
// Store: lid_x is the channel within a 4-channel group and lid_y selects one
// of the 4 groups in this 16-channel block. Pixels past output_width are
// dropped. In the last channel block, lanes whose channel >= output_channel
// write nothing, so the padded tail of the last 4-channel group is not
// written by this kernel. NOTE(review): confirm consumers do not read it.
" const uint lid_x=sglid % 4;\n"
" const uint lid_y=sglid/4;\n"
" if ((feature_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<8 && (x+i)<output_width; i++) {\n"
" if ((feature_block*16+lid_y*4+lid_x<output_channel))\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" for (int i=0; i<8 && (x+i)<output_width; i++) {\n"
" output[output_offset+lid_y*output_b_pitch+i*output_x_pitch+lid_x]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" }\n"
"#endif\n"
"}\n"
// ---------------------------------------------------------------------------
// OpenCL kernel conv_2d_buf_subgroup_c16_c16_b2 (source text held in this
// string). Direct 2D convolution using sub-groups of exactly 16 lanes
// (intel_reqd_sub_group_size(16)). Both the input and the output buffers
// store 16 channels per pixel with horizontally padded rows. Each lane
// accumulates 2 horizontally adjacent output pixels (x, x+1) for output
// channel feature_block*16 + sglid.
// NDRange mapping, as derived from the index math below:
//   dim 0 : flattened (x block, y): x = (xy % x_blocks) * 2, y = xy / x_blocks
//   dim 1 : group id = 16-channel output block; the local range is divided
//           into SLM_DIV_FACTOR sub-blocks that split the input-channel loop
//   dim 2 : batch index
// STRIDE_*, DILATION_*, FILTER_*, INPUT_CHANNEL, IC_BLOCKS, INPUT_LINE_SIZE,
// SLM_DIV_FACTOR, WORK_GROUP_SIZE, RELU and RELU6 are not defined in this
// kernel. They presumably come from the host build options; confirm this
// against the launching code.
// ---------------------------------------------------------------------------
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c16_c16_b2(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
") {\n"
// Work-item coordinates: lane id, batch, first output pixel (x, y), and the
// SLM sub-block / 16-channel output block this lane works on.
" const int sglid=get_sub_group_local_id();\n"
" const int b=(uint)get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 1;\n"
" const int y=(xy/x_blocks);\n"
" const int lid1=(int)get_local_id(1);\n"
" const int feature_per_wg=(int)get_local_size(1)/SLM_DIV_FACTOR;\n"
" const int feature_sub_block=lid1/feature_per_wg;\n"
" const int feature_block=(int)get_group_id(1);\n"
// Top-left input coordinate of the receptive field. It can be negative when
// pad_width / pad_height > 0.
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
// Input layout [N][ceil(INPUT_CHANNEL/16)][H][pad_left+W+pad_right][16].
// A negative input_y makes the unsigned offset wrap. Only rows that pass the
// kh bounds check below are actually read.
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(input_height);\n"
" const uint input_b_pitch=input_fs_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
// Output layout [N][ceil(output_channel/16)][H][pad_left+W+pad_right][16].
// output_offset already includes the output_pad_left column shift.
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(output_pad_left+output_width+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*((output_channel+15)/16);\n"
" const uint output_offset=b*output_b_pitch +\n"
" feature_block*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
// Filter pitches: one 16x16 tile per (output block, input block, kh, kw).
// Consecutive input channels inside a tile are filter_isv_pitch=16 apart.
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=16*16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=feature_block*filter_os_pitch;\n"
// The accumulators start from the bias broadcast to both pixels. With an SLM
// split, only sub-block 0 starts from the bias, so the local-memory
// reduction adds it exactly once.
"#if SLM_DIV_FACTOR == 1\n"
" COMPUTE_FLOAT2 dst=(COMPUTE_FLOAT2)(GROUP_READ(biases,feature_block*16));\n"
"#else\n"
" COMPUTE_FLOAT2 dst;\n"
" if (feature_sub_block == 0) {\n"
" dst=(COMPUTE_FLOAT2)(GROUP_READ(biases,feature_block*16));\n"
" } else {\n"
" dst=(COMPUTE_FLOAT2)0;\n"
" }\n"
"#endif \n"
// Local scratch for the cross-sub-block partial-sum reduction.
"#if SLM_DIV_FACTOR>1\n"
" __local COMPUTE_FLOAT2 sum[WORK_GROUP_SIZE];\n"
"#endif\n"
// Loop over 16-channel input blocks. With an SLM split, each sub-block
// handles its own contiguous share of IC_BLOCKS.
"#if SLM_DIV_FACTOR>1\n"
" for (int icb=feature_sub_block*IC_BLOCKS/SLM_DIV_FACTOR; icb<(feature_sub_block+1)*IC_BLOCKS/SLM_DIV_FACTOR; icb++) {\n"
"#else\n"
" for (int icb=0; icb<IC_BLOCKS; icb++) {\n"
"#endif \n"
// Filter rows that land outside [0, input_height) are skipped, which
// handles vertical padding. Horizontal padding is not checked here; it relies
// on the input_pad_left/right columns of the input buffer, which are
// presumably zero-filled by the producer (TODO confirm).
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++) {\n"
" if (input_y+kh*DILATION_HEIGHT<0 || input_y+kh*DILATION_HEIGHT >= input_height)\n"
" continue;\n"
// Cache INPUT_LINE_SIZE consecutive pixels of this input row. The reads are
// 8 wide first, then 4 wide, then scalar for the remainder.
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" {\n"
" int xb=0;\n"
" for (; xb+8 <= INPUT_LINE_SIZE; xb += 8) {\n"
" COMPUTE_FLOAT8 tmp=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" line_cache[xb+4]=tmp[4];\n"
" line_cache[xb+5]=tmp[5];\n"
" line_cache[xb+6]=tmp[6];\n"
" line_cache[xb+7]=tmp[7];\n"
" }\n"
" for (; xb+4 <= INPUT_LINE_SIZE; xb += 4) {\n"
" COMPUTE_FLOAT4 tmp=CONVERT_COMPUTE_FLOAT4(GROUP_READ4(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" }\n"
" for (; xb<INPUT_LINE_SIZE; xb++) {\n"
" line_cache[xb]=(COMPUTE_FLOAT)GROUP_READ(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch);\n"
" }\n"
" }\n"
// For each kw, gather the cached pixel that each of the 2 output positions
// needs: column kw*DILATION_WIDTH + STRIDE_WIDTH*i of the line.
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++) {\n"
" FLOAT2 src;\n"
" __attribute__((opencl_unroll_hint(2)))\n"
" for (int i=0; i<2; i++) {\n"
"#if FILTER_WIDTH == 1 && DILATION_WIDTH == 1 && STRIDE_WIDTH == 1\n"
" src[i]=line_cache[i];\n"
"#else\n"
" src[i]=line_cache[kw*DILATION_WIDTH+STRIDE_WIDTH*i];\n"
"#endif\n"
" }\n"
// weight0 covers input channels 0-7 of the 16x16 tile. weight1, offset by
// 8*filter_isv_pitch, covers input channels 8-15.
" COMPUTE_FLOAT8 weight0=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch));\n"
" COMPUTE_FLOAT8 weight1=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch +\n"
" 8*filter_isv_pitch));\n"
// srcN holds the 2 gathered pixels taken from lane N. This assumes lane N
// holds input channel N of the block; GROUP_SHUFFLE2 is a macro defined
// elsewhere, presumably a sub-group shuffle.
" const COMPUTE_FLOAT2 src0=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,0));\n"
" const COMPUTE_FLOAT2 src1=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,1));\n"
" const COMPUTE_FLOAT2 src2=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,2));\n"
" const COMPUTE_FLOAT2 src3=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,3));\n"
" const COMPUTE_FLOAT2 src4=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,4));\n"
" const COMPUTE_FLOAT2 src5=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,5));\n"
" const COMPUTE_FLOAT2 src6=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,6));\n"
" const COMPUTE_FLOAT2 src7=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,7));\n"
" const COMPUTE_FLOAT2 src8=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,8));\n"
" const COMPUTE_FLOAT2 src9=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,9));\n"
" const COMPUTE_FLOAT2 src10=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,10));\n"
" const COMPUTE_FLOAT2 src11=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,11));\n"
" const COMPUTE_FLOAT2 src12=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,12));\n"
" const COMPUTE_FLOAT2 src13=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,13));\n"
" const COMPUTE_FLOAT2 src14=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,14));\n"
" const COMPUTE_FLOAT2 src15=CONVERT_COMPUTE_FLOAT2(GROUP_SHUFFLE2(src,15));\n"
// Accumulate weight(input channel N) * srcN over all 16 input channels.
" dst=mad(weight0.s0,src0,dst);\n"
" dst=mad(weight0.s1,src1,dst);\n"
" dst=mad(weight0.s2,src2,dst);\n"
" dst=mad(weight0.s3,src3,dst);\n"
" dst=mad(weight0.s4,src4,dst);\n"
" dst=mad(weight0.s5,src5,dst);\n"
" dst=mad(weight0.s6,src6,dst);\n"
" dst=mad(weight0.s7,src7,dst);\n"
" dst=mad(weight1.s0,src8,dst);\n"
" dst=mad(weight1.s1,src9,dst);\n"
" dst=mad(weight1.s2,src10,dst);\n"
" dst=mad(weight1.s3,src11,dst);\n"
" dst=mad(weight1.s4,src12,dst);\n"
" dst=mad(weight1.s5,src13,dst);\n"
" dst=mad(weight1.s6,src14,dst);\n"
" dst=mad(weight1.s7,src15,dst);\n"
" }\n"
" }\n"
" }\n"
// Work-items at x == 0 zero the output_pad_left columns and the
// output_pad_right columns of this output row. This runs before, and outside,
// the SLM sub-block-0 guard, so with SLM_DIV_FACTOR > 1 every sub-block
// writes the same zeros. That is redundant but harmless.
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+feature_block*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" pad_offset += (output_width+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" }\n"
" \n"
// SLM reduction: every sub-block publishes its partial sum. After the
// barrier, sub-block 0 adds the sums of sub-blocks 1..SLM_DIV_FACTOR-1.
// The `if` opened here is closed by the trailing #if SLM_DIV_FACTOR>1, so
// only sub-block 0 applies the activation and stores.
"#if SLM_DIV_FACTOR>1\n"
" sum[lid1]=dst;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" if (feature_sub_block == 0) {\n"
" __attribute__((opencl_unroll_hint)) for(int i=1; i<SLM_DIV_FACTOR; i++)\n"
" dst += sum[lid1 % feature_per_wg+i*feature_per_wg];\n"
"#endif\n"
// Optional fused activation, selected at build time.
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n"
"#endif\n"
// Store. In the last channel block, each lane writes only if its channel is
// < output_channel and the pixel is < output_width. In the other blocks, a
// GROUP_WRITE2 writes both pixels when x+2 fits. Note that the
// `output_width % 2 == 0` clause takes this path for every x of an
// even-width output. The else branch is reached only for an odd width with
// x+2 > output_width, and it writes the output_width % 2 (== 1) remaining
// column.
" if ((feature_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<2; i++) {\n"
" if ((feature_block*16+sglid<output_channel) && (x+i)<output_width)\n"
" output[output_offset+i*output_x_pitch+sglid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" if (x+2 <= output_width || output_width % 2 == 0) {\n"
" GROUP_WRITE2(output,output_offset,CONVERT_FLOAT2(dst));\n"
" }else{\n"
" for (int i=0; i<output_width % 2; i++) {\n"
" output[output_offset+i*output_x_pitch+sglid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" }\n"
"#endif\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c16_c16_b4(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
") {\n"
" const int sglid=get_sub_group_local_id();\n"
" const int b=(uint)get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 2;\n"
" const int y=(xy/x_blocks);\n"
" const int lid1=(int)get_local_id(1);\n"
" const int feature_per_wg=(int)get_local_size(1)/SLM_DIV_FACTOR;\n"
" const int feature_sub_block=lid1/feature_per_wg;\n"
" const int feature_block=(int)get_group_id(1);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(input_height);\n"
" const uint input_b_pitch=input_fs_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(output_pad_left+output_width+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*((output_channel+15)/16);\n"
" const uint output_offset=b*output_b_pitch +\n"
" feature_block*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=16*16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=feature_block*filter_os_pitch;\n"
"#if SLM_DIV_FACTOR == 1\n"
" COMPUTE_FLOAT4 dst=(COMPUTE_FLOAT4)(GROUP_READ(biases,feature_block*16));\n"
"#else\n"
" COMPUTE_FLOAT4 dst;\n"
" if (feature_sub_block == 0) {\n"
" dst=(COMPUTE_FLOAT4)(GROUP_READ(biases,feature_block*16));\n"
" } else {\n"
" dst=(COMPUTE_FLOAT4)0;\n"
" }\n"
"#endif \n"
"#if SLM_DIV_FACTOR>1\n"
" __local COMPUTE_FLOAT4 sum[WORK_GROUP_SIZE];\n"
"#endif\n"
"#if SLM_DIV_FACTOR>1\n"
" for (int icb=feature_sub_block*IC_BLOCKS/SLM_DIV_FACTOR; icb<(feature_sub_block+1)*IC_BLOCKS/SLM_DIV_FACTOR; icb++) {\n"
"#else\n"
" for (int icb=0; icb<IC_BLOCKS; icb++) {\n"
"#endif \n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++) {\n"
" if (input_y+kh*DILATION_HEIGHT<0 || input_y+kh*DILATION_HEIGHT >= input_height)\n"
" continue;\n"
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" {\n"
" int xb=0;\n"
" for (; xb+8 <= INPUT_LINE_SIZE; xb += 8) {\n"
" COMPUTE_FLOAT8 tmp=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" line_cache[xb+4]=tmp[4];\n"
" line_cache[xb+5]=tmp[5];\n"
" line_cache[xb+6]=tmp[6];\n"
" line_cache[xb+7]=tmp[7];\n"
" }\n"
" for (; xb+4 <= INPUT_LINE_SIZE; xb += 4) {\n"
" COMPUTE_FLOAT4 tmp=CONVERT_COMPUTE_FLOAT4(GROUP_READ4(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" }\n"
" for (; xb<INPUT_LINE_SIZE; xb++) {\n"
" line_cache[xb]=(COMPUTE_FLOAT)GROUP_READ(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch);\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++) {\n"
" FLOAT4 src;\n"
" __attribute__((opencl_unroll_hint(4)))\n"
" for (int i=0; i<4; i++) {\n"
"#if FILTER_WIDTH == 1 && DILATION_WIDTH == 1 && STRIDE_WIDTH == 1\n"
" src[i]=line_cache[i];\n"
"#else\n"
" src[i]=line_cache[kw*DILATION_WIDTH+STRIDE_WIDTH*i];\n"
"#endif\n"
" }\n"
" COMPUTE_FLOAT8 weight0=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch));\n"
" COMPUTE_FLOAT8 weight1=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch +\n"
" 8*filter_isv_pitch));\n"
" const COMPUTE_FLOAT4 src0=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,0));\n"
" const COMPUTE_FLOAT4 src1=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,1));\n"
" const COMPUTE_FLOAT4 src2=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,2));\n"
" const COMPUTE_FLOAT4 src3=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,3));\n"
" const COMPUTE_FLOAT4 src4=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,4));\n"
" const COMPUTE_FLOAT4 src5=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,5));\n"
" const COMPUTE_FLOAT4 src6=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,6));\n"
" const COMPUTE_FLOAT4 src7=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,7));\n"
" const COMPUTE_FLOAT4 src8=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,8));\n"
" const COMPUTE_FLOAT4 src9=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,9));\n"
" const COMPUTE_FLOAT4 src10=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,10));\n"
" const COMPUTE_FLOAT4 src11=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,11));\n"
" const COMPUTE_FLOAT4 src12=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,12));\n"
" const COMPUTE_FLOAT4 src13=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,13));\n"
" const COMPUTE_FLOAT4 src14=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,14));\n"
" const COMPUTE_FLOAT4 src15=CONVERT_COMPUTE_FLOAT4(GROUP_SHUFFLE4(src,15));\n"
" dst=mad(weight0.s0,src0,dst);\n"
" dst=mad(weight0.s1,src1,dst);\n"
" dst=mad(weight0.s2,src2,dst);\n"
" dst=mad(weight0.s3,src3,dst);\n"
" dst=mad(weight0.s4,src4,dst);\n"
" dst=mad(weight0.s5,src5,dst);\n"
" dst=mad(weight0.s6,src6,dst);\n"
" dst=mad(weight0.s7,src7,dst);\n"
" dst=mad(weight1.s0,src8,dst);\n"
" dst=mad(weight1.s1,src9,dst);\n"
" dst=mad(weight1.s2,src10,dst);\n"
" dst=mad(weight1.s3,src11,dst);\n"
" dst=mad(weight1.s4,src12,dst);\n"
" dst=mad(weight1.s5,src13,dst);\n"
" dst=mad(weight1.s6,src14,dst);\n"
" dst=mad(weight1.s7,src15,dst);\n"
" }\n"
" }\n"
" }\n"
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+feature_block*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" pad_offset += (output_width+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" sum[lid1]=dst;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" if (feature_sub_block == 0) {\n"
" __attribute__((opencl_unroll_hint)) for(int i=1; i<SLM_DIV_FACTOR; i++)\n"
" dst += sum[lid1 % feature_per_wg+i*feature_per_wg];\n"
"#endif\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" if ((feature_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<4; i++) {\n"
" if ((feature_block*16+sglid<output_channel) && (x+i)<output_width)\n"
" output[output_offset+i*output_x_pitch+sglid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" if (x+4 <= output_width || output_width % 4 == 0) {\n"
" GROUP_WRITE4(output,output_offset,CONVERT_FLOAT4(dst));\n"
" }else{\n"
" for (int i=0; i<output_width % 4; i++) {\n"
" output[output_offset+i*output_x_pitch+sglid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" }\n"
"#endif\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_2d_buf_subgroup_c16_c16_b8(\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __global FLOAT* weights,\n"
" __global FLOAT* biases,\n"
" __private const int pad_width,\n"
" __private const int pad_height,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int output_width,\n"
" __private const int output_height,\n"
" __private const int output_channel,\n"
" __private const int batch,\n"
" __private const int x_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right,\n"
" __private const int output_pad_left,\n"
" __private const int output_pad_right\n"
") {\n"
" const int sglid=get_sub_group_local_id();\n"
" const int b=(uint)get_global_id(2);\n"
" const int xy=get_global_id(0);\n"
" const int x=(xy % x_blocks) << 3;\n"
" const int y=(xy/x_blocks);\n"
" const int lid1=(int)get_local_id(1);\n"
" const int feature_per_wg=(int)get_local_size(1)/SLM_DIV_FACTOR;\n"
" const int feature_sub_block=lid1/feature_per_wg;\n"
" const int feature_block=(int)get_group_id(1);\n"
" const int input_x=x*STRIDE_WIDTH-pad_width;\n"
" const int input_y=y*STRIDE_HEIGHT-pad_height;\n"
" const uint input_x_pitch=16;\n"
" const uint input_y_pitch=input_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint input_fs_pitch=input_y_pitch*(input_height);\n"
" const uint input_b_pitch=input_fs_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint input_offset=b*input_b_pitch +\n"
" input_y*input_y_pitch +\n"
" (input_x+input_pad_left)*input_x_pitch;\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(output_pad_left+output_width+output_pad_right);\n"
" const uint output_fs_pitch=output_y_pitch*output_height;\n"
" const uint output_b_pitch=output_fs_pitch*((output_channel+15)/16);\n"
" const uint output_offset=b*output_b_pitch +\n"
" feature_block*output_fs_pitch +\n"
" y*output_y_pitch +\n"
" (x+output_pad_left)*output_x_pitch;\n"
" const uint filter_isv_pitch=16;\n"
" const uint filter_x_pitch=16*16;\n"
" const uint filter_y_pitch=filter_x_pitch*FILTER_WIDTH;\n"
" const uint filter_is_pitch=filter_y_pitch*FILTER_HEIGHT;\n"
" const uint filter_os_pitch=filter_is_pitch*((INPUT_CHANNEL+15)/16);\n"
" const uint filter_offset=feature_block*filter_os_pitch;\n"
"#if SLM_DIV_FACTOR == 1\n"
" COMPUTE_FLOAT8 dst=(COMPUTE_FLOAT8)(GROUP_READ(biases,feature_block*16));\n"
"#else\n"
" COMPUTE_FLOAT8 dst;\n"
" if (feature_sub_block == 0) {\n"
" dst=(COMPUTE_FLOAT8)(GROUP_READ(biases,feature_block*16));\n"
" } else {\n"
" dst=(COMPUTE_FLOAT8)0;\n"
" }\n"
"#endif \n"
"#if SLM_DIV_FACTOR>1\n"
" __local COMPUTE_FLOAT8 sum[WORK_GROUP_SIZE];\n"
"#endif\n"
"#if SLM_DIV_FACTOR>1\n"
" for (int icb=feature_sub_block*IC_BLOCKS/SLM_DIV_FACTOR; icb<(feature_sub_block+1)*IC_BLOCKS/SLM_DIV_FACTOR; icb++) {\n"
"#else\n"
" for (int icb=0; icb<IC_BLOCKS; icb++) {\n"
"#endif \n"
" __attribute__((opencl_unroll_hint(FILTER_HEIGHT)))\n"
" for (int kh=0; kh<FILTER_HEIGHT; kh++) {\n"
" if (input_y+kh*DILATION_HEIGHT<0 || input_y+kh*DILATION_HEIGHT >= input_height)\n"
" continue;\n"
" FLOAT line_cache[INPUT_LINE_SIZE];\n"
" {\n"
" int xb=0;\n"
" for (; xb+8 <= INPUT_LINE_SIZE; xb += 8) {\n"
" COMPUTE_FLOAT8 tmp=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" line_cache[xb+4]=tmp[4];\n"
" line_cache[xb+5]=tmp[5];\n"
" line_cache[xb+6]=tmp[6];\n"
" line_cache[xb+7]=tmp[7];\n"
" }\n"
" for (; xb+4 <= INPUT_LINE_SIZE; xb += 4) {\n"
" COMPUTE_FLOAT4 tmp=CONVERT_COMPUTE_FLOAT4(GROUP_READ4(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch));\n"
" \n"
" line_cache[xb+0]=tmp[0];\n"
" line_cache[xb+1]=tmp[1];\n"
" line_cache[xb+2]=tmp[2];\n"
" line_cache[xb+3]=tmp[3];\n"
" }\n"
" for (; xb<INPUT_LINE_SIZE; xb++) {\n"
" line_cache[xb]=(COMPUTE_FLOAT)GROUP_READ(input,input_offset +\n"
" icb*input_fs_pitch +\n"
" kh*DILATION_HEIGHT*input_y_pitch +\n"
" xb*input_x_pitch);\n"
" }\n"
" }\n"
" __attribute__((opencl_unroll_hint(FILTER_WIDTH)))\n"
" for (int kw=0; kw<FILTER_WIDTH; kw++) {\n"
" FLOAT8 src;\n"
" __attribute__((opencl_unroll_hint(8)))\n"
" for (int i=0; i<8; i++) {\n"
"#if FILTER_WIDTH == 1 && DILATION_WIDTH == 1 && STRIDE_WIDTH == 1\n"
" src[i]=line_cache[i];\n"
"#else\n"
" src[i]=line_cache[kw*DILATION_WIDTH+STRIDE_WIDTH*i];\n"
"#endif\n"
" }\n"
" COMPUTE_FLOAT8 weight0=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch));\n"
" COMPUTE_FLOAT8 weight1=CONVERT_COMPUTE_FLOAT8(GROUP_READ8(weights,filter_offset +\n"
" icb*filter_is_pitch +\n"
" kh*filter_y_pitch +\n"
" kw*filter_x_pitch +\n"
" 8*filter_isv_pitch));\n"
" const COMPUTE_FLOAT8 src0=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,0));\n"
" const COMPUTE_FLOAT8 src1=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,1));\n"
" const COMPUTE_FLOAT8 src2=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,2));\n"
" const COMPUTE_FLOAT8 src3=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,3));\n"
" const COMPUTE_FLOAT8 src4=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,4));\n"
" const COMPUTE_FLOAT8 src5=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,5));\n"
" const COMPUTE_FLOAT8 src6=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,6));\n"
" const COMPUTE_FLOAT8 src7=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,7));\n"
" const COMPUTE_FLOAT8 src8=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,8));\n"
" const COMPUTE_FLOAT8 src9=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,9));\n"
" const COMPUTE_FLOAT8 src10=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,10));\n"
" const COMPUTE_FLOAT8 src11=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,11));\n"
" const COMPUTE_FLOAT8 src12=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,12));\n"
" const COMPUTE_FLOAT8 src13=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,13));\n"
" const COMPUTE_FLOAT8 src14=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,14));\n"
" const COMPUTE_FLOAT8 src15=CONVERT_COMPUTE_FLOAT8(GROUP_SHUFFLE8(src,15));\n"
" dst=mad(weight0.s0,src0,dst);\n"
" dst=mad(weight0.s1,src1,dst);\n"
" dst=mad(weight0.s2,src2,dst);\n"
" dst=mad(weight0.s3,src3,dst);\n"
" dst=mad(weight0.s4,src4,dst);\n"
" dst=mad(weight0.s5,src5,dst);\n"
" dst=mad(weight0.s6,src6,dst);\n"
" dst=mad(weight0.s7,src7,dst);\n"
" dst=mad(weight1.s0,src8,dst);\n"
" dst=mad(weight1.s1,src9,dst);\n"
" dst=mad(weight1.s2,src10,dst);\n"
" dst=mad(weight1.s3,src11,dst);\n"
" dst=mad(weight1.s4,src12,dst);\n"
" dst=mad(weight1.s5,src13,dst);\n"
" dst=mad(weight1.s6,src14,dst);\n"
" dst=mad(weight1.s7,src15,dst);\n"
" }\n"
" }\n"
" }\n"
" \n"
" \n"
" if(x == 0){\n"
" uint pad_offset=b*output_b_pitch+feature_block*output_fs_pitch+y*output_y_pitch;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" pad_offset += (output_width+output_pad_left)*output_x_pitch;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" output[pad_offset+i*output_x_pitch+sglid]=0;\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" sum[lid1]=dst;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" if (feature_sub_block == 0) {\n"
" __attribute__((opencl_unroll_hint)) for(int i=1; i<SLM_DIV_FACTOR; i++)\n"
" dst += sum[lid1 % feature_per_wg+i*feature_per_wg];\n"
"#endif\n"
"#ifdef RELU\n"
" dst=fmax(dst,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" dst=clamp(dst,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
" if ((feature_block+1)*16 >= output_channel) {\n"
" for (int i=0; i<8; i++) {\n"
" if ((feature_block*16+sglid<output_channel) && (x+i)<output_width)\n"
" output[output_offset+i*output_x_pitch+sglid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" else\n"
" {\n"
" if (x+8 <= output_width || output_width % 8 == 0) {\n"
" GROUP_WRITE8(output,output_offset,CONVERT_FLOAT8(dst));\n"
" }else{\n"
" for (int i=0; i<output_width % 8; i++) {\n"
" output[output_offset+i*output_x_pitch+sglid]=(FLOAT)dst[i];\n"
" }\n"
" }\n"
" }\n"
"#if SLM_DIV_FACTOR>1\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL program source "input_transe_buf": two layout-transform kernels that
// read a buffer packed in groups of 4 channels and rewrite it either as a
// planar (one channel per plane) buffer or as a 16-channel-packed buffer with
// optional left/right spatial padding.
//
// Both kernels share the same launch layout:
//   dim0: x = h * input_width + w    (bounded by global_size_dim0)
//   dim1: c = index of a 4-channel block (bounded by global_size_dim1)
//   dim2: b = batch index            (bounded by global_size_dim2)
// Input addressing (from the offset math): the 4-channel block index is the
// outermost dimension and batch the next, i.e. [c/4][b][h][w][4]. Note that
// the kernel-local names input_f_pitch / input_b_pitch are used the other way
// round (b is multiplied by input_f_pitch, c by input_b_pitch).
// Both kernels carry intel_reqd_sub_group_size(16), although neither body
// calls any sub-group built-in.
// FLOAT / FLOAT4 are not defined in this string; presumably they are supplied
// by the host as build options -- confirm against the program builder.
const char* input_transe_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// conv_transe_c4_c1: copies one 4-channel pixel into a planar output laid out
// as [b][channel][h][w] (output_f_pitch = W*H per channel,
// output_b_pitch = W*H*input_channel per batch). Lanes whose channel index
// cout + i reaches input_channel are not written. channel_blocks,
// input_pad_left and input_pad_right are accepted but not read.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_transe_c4_c1(\n"
" int global_size_dim0,\n"
" int global_size_dim1,\n"
" int global_size_dim2,\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" __private const int input_width,\n"
" __private const int input_height,\n"
" __private const int input_channel,\n"
" __private const int batch,\n"
" __private const int channel_blocks,\n"
" __private const int input_pad_left,\n"
" __private const int input_pad_right)\n"
"{\n"
" int x=get_global_id(0);\n"
" int w=x % input_width;\n"
" int h=x/input_width;\n"
" int c=get_global_id(1);\n"
" int b=get_global_id(2);\n"
" int cout=c << 2;\n"
" if(x >= global_size_dim0 || c >= global_size_dim1 || b >= global_size_dim2)\n"
" return;\n"
" // Input offset calculations:\n"
" const uint input_x_pitch=4;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*batch;\n"
" const uint input_offset=b*input_f_pitch +\n"
" c*input_b_pitch +\n"
" h*input_y_pitch +\n"
" w*input_x_pitch;\n"
" // Output offset calculations:\n"
" const uint output_x_pitch=1;\n"
" const uint output_y_pitch=output_x_pitch*input_width;\n"
" const uint output_f_pitch=output_y_pitch*input_height;\n"
" const uint output_b_pitch=output_f_pitch*input_channel;\n"
" const uint output_offset=b*output_b_pitch +\n"
" cout*output_f_pitch+\n"
" h*output_y_pitch +\n"
" w*output_x_pitch;\n"
" \n"
" FLOAT4 value=vload4(0,input+input_offset);\n"
" for(int i=0; i<4 && cout+i<input_channel; ++i){\n"
" output[output_offset+i*output_f_pitch]=value[i];\n"
" }\n"
"}\n"
// conv_transe_c4_c16: copies one 4-channel pixel into a 16-channel-packed,
// width-padded output laid out as
// [b][ceil(input_channel/16)][h][input_pad_left + W + input_pad_right][16].
// The 4 values land in lanes (c % 4) * 4 .. (c % 4) * 4 + 3 of channel group
// c >> 2. The work item with w == 0 also zero-fills the input_pad_left left
// columns and input_pad_right right columns of its row for those 4 lanes.
// channel_blocks is accepted but not read.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void conv_transe_c4_c16(\n"
" int global_size_dim0,\n"
" int global_size_dim1,\n"
" int global_size_dim2,\n"
" __global FLOAT* input,\n"
" __global FLOAT* output,\n"
" int input_width,\n"
" int input_height,\n"
" int input_channel,\n"
" int batch,\n"
" int channel_blocks,\n"
" int input_pad_left,\n"
" int input_pad_right)\n"
"{\n"
" int x=get_global_id(0);\n"
" int w=x % input_width;\n"
" int h=x/input_width;\n"
" int c=get_global_id(1);\n"
" int b=get_global_id(2);\n"
" int cout=c >> 2;\n"
" if(x >= global_size_dim0 || c >= global_size_dim1 || b >= global_size_dim2)\n"
" return;\n"
" \n"
" // Input offset calculations:\n"
" const uint input_x_pitch=4;\n"
" const uint input_y_pitch=input_x_pitch*input_width;\n"
" const uint input_f_pitch=input_y_pitch*input_height;\n"
" const uint input_b_pitch=input_f_pitch*batch;\n"
" \n"
" const uint input_offset=b*input_f_pitch +\n"
" c*input_b_pitch +\n"
" h*input_y_pitch +\n"
" w*input_x_pitch;\n"
" \n"
" // Output offset calculations:\n"
" const uint output_x_pitch=16;\n"
" const uint output_y_pitch=output_x_pitch*(input_pad_left+input_width+input_pad_right);\n"
" const uint output_f_pitch=output_y_pitch*input_height;\n"
" const uint output_b_pitch=output_f_pitch*((input_channel+15)/16);\n"
" \n"
" const uint output_offset=b*output_b_pitch +\n"
" cout*output_f_pitch+\n"
" h*output_y_pitch +\n"
" (w+input_pad_left)*output_x_pitch+(c % 4)*4;\n"
" \n"
" FLOAT4 value=vload4(0,input+input_offset);\n"
" vstore4(value,0,output+output_offset);\n"
" if(w == 0){\n"
" uint pad_offset=b*output_b_pitch+cout*output_f_pitch+h*output_y_pitch+(c % 4)*4;\n"
" for(int i=0; i<input_pad_left; ++i){\n"
" vstore4((FLOAT4)0,0,output+pad_offset+i*output_x_pitch);\n"
" }\n"
" pad_offset += (input_pad_left+input_width)*output_x_pitch;\n"
" for(int i=0; i<input_pad_right; ++i){\n"
" vstore4((FLOAT4)0,0,output+pad_offset+i*output_x_pitch);\n"
" }\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL program source "reduction_buf": reduces a buffer viewed as
// [outside][dim][inside] along the middle axis `dim`, writing an
// [outside][inside] result (input offset z*dim*inside + y + i*inside, output
// offset z*inside + y).
//
// Macros the kernels use but this string does not define (presumably passed
// as build options by the host -- confirm): INPUT_TYPE, INPUT_TYPE4,
// OUTPUT_TYPE, CONVERT_OUTPUT4, VALUE (initial accumulator value),
// OPERATE(a, b) (combining step), REDUCT_LOCAL_SIZE, and optionally GET_AVG
// (divides the reduced value by dim).
//
// Launch layout: dim0 = x (lane of a cooperative reduction), dim1 = y (inside
// index), dim2 = z (outside index).
// When REDUCT_LOCAL_SIZE > 4, the work items of a group that share a (y, z)
// pair cooperate through __local memory:
//   * each lane strides over dim by REDUCT_LOCAL_SIZE starting at
//     get_local_id(0), and sum[] has REDUCT_LOCAL_SIZE entries indexed by that
//     id, so the local size in dim0 is expected to equal REDUCT_LOCAL_SIZE;
//   * the halving loop only folds every sum[] entry into sum[0] when
//     REDUCT_LOCAL_SIZE is a power of two;
//   * DEAL_NON_UNIFORM_DIM3 can return before the barrier() calls, and every
//     work item of a work-group must reach those barriers -- NOTE(review): the
//     host presumably sizes the launch so that no lane returns early; confirm;
//   * every lane stores the same final sum[0] to the same output element.
// Otherwise (REDUCT_LOCAL_SIZE <= 4) each work item reduces its (y, z) column
// serially.
const char* reduction_buf = 
"// TODO: use INIT_SCALAR_VALUE,OPERATOR,FINAL_OPERATOR_ON_CHANNEL macro abstract and simplify code\n"
"// TODO: support reduce dims include batch\n"
"// TODO: support keep_dim=False\n"
"// TODO: fix channel reduce result re-pack problem\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// reduct_buf: scalar variant; one (y, z) output element per reduction.
"__kernel void reduct_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside\n"
" const int z=get_global_id(2); // outside\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" INPUT_TYPE out=(INPUT_TYPE)VALUE;\n"
" const int offset=z*dim*inside+y;\n"
" \n"
"#if REDUCT_LOCAL_SIZE>4\n"
" const int lid=get_local_id(0);\n"
" INPUT_TYPE local sum[REDUCT_LOCAL_SIZE];\n"
" for(int i=lid; i<dim; i+=REDUCT_LOCAL_SIZE){\n"
" INPUT_TYPE in=(INPUT_TYPE)input[offset+i*inside];\n"
" out=OPERATE(out,in);\n"
" }\n"
" sum[lid]=out;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=REDUCT_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=OPERATE(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out=sum[0];\n"
"#else\n"
" for(int i=0; i<dim; ++i){\n"
" INPUT_TYPE in=(INPUT_TYPE)input[offset+i*inside];\n"
" out=OPERATE(out,in);\n"
" }\n"
"#endif\n"
"#ifdef GET_AVG\n"
" out=out/dim;\n"
"#endif\n"
" output[z*inside+y]=(OUTPUT_TYPE)out;\n"
"}\n"
// reduct_v4_buf: same reduction, but each work item handles 4 consecutive
// inside positions. y indexes groups of 4 (offsets use y << 2) while `inside`
// remains a count of scalar elements; data moves through vload4 / vstore4.
"__kernel void reduct_v4_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input,\n"
" __global OUTPUT_TYPE *output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim) {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside\n"
" const int z=get_global_id(2); // outside\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" INPUT_TYPE4 out=(INPUT_TYPE4)VALUE;\n"
" const int offset=z*dim*inside+(y << 2);\n"
" \n"
"#if REDUCT_LOCAL_SIZE>4\n"
" const int lid=get_local_id(0);\n"
" INPUT_TYPE4 local sum[REDUCT_LOCAL_SIZE];\n"
" for(int i=lid; i<dim; i+=REDUCT_LOCAL_SIZE){\n"
" INPUT_TYPE4 in=vload4(0,input+offset+i*inside);\n"
" out=OPERATE(out,in);\n"
" }\n"
" sum[lid]=out;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=REDUCT_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=OPERATE(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out=sum[0];\n"
"#else\n"
" for(int i=0; i<dim; ++i){\n"
" INPUT_TYPE4 in=vload4(0,input+offset+i*inside);\n"
" out=OPERATE(out,in);\n"
" }\n"
"#endif\n"
"#ifdef GET_AVG\n"
" out=out/(INPUT_TYPE4)dim;\n"
"#endif\n"
" vstore4(CONVERT_OUTPUT4(out),0,output+z*inside+(y << 2));\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL program source "strassen_binary_buf": element-wise helpers used
// around a Strassen matrix multiplication. In both kernels each work item
// processes 8 consecutive elements of a row (vload8 / vstore8 at pos.x * 8),
// and get_global_id(1) selects the row (or row group). Work items outside
// [global_dim0, global_dim1) do nothing.
// FLOAT / FLOAT8, and VEC_H for binary_function_buf, are not defined in this
// string; presumably they are supplied as build options -- confirm against
// the host.
const char* strassen_binary_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// binary_cfunction_buf: merge step over the four quadrants of C held in
// input0 (row stride strideC, origin offsetC):
//   c11 at offsetC + x*8 + y*strideC, c12 = c11 + width,
//   c21 = c11 + height*strideC,       c22 = c21 + width.
// With cx read from input1 (row stride `width`), it stores to output at the
// same c21 / c22 / c12 offsets:
//   c21 <- c12 + cx + c21
//   c22 <- c22 + c12 + cx + c21
//   c12 <- c11 + c22 + c12 + cx
// c11 is read but not stored.
// NOTE(review): the in-kernel comments on the `width` / `height` parameters
// ("[offsetA,...]", "[strideA,...]") and the "[X/16,Y]" note appear to be
// copied from binary_function_buf. From the indexing, width is the column
// offset of the right-hand quadrants, height is the row offset of the lower
// ones, and each work item covers 8 columns.
"__kernel void binary_cfunction_buf(__private int global_dim0,__private int global_dim1,\n"
" __global FLOAT* input0,\n"
" __private const int offsetC,\n"
" __private const int strideC,\n"
" __global FLOAT* input1,__global FLOAT* output,\n"
" __private const int width,//[offsetA,offsetB,offsetC,0]\n"
" __private const int height//[strideA,strideB,strideC,0]\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));// [X/16,Y]\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" int offset_11=offsetC+pos.x*8+pos.y*strideC;\n"
" int offset_12=offset_11+width;\n"
" int offset_21=offset_11+strideC*height;\n"
" int offset_22=offset_21+width;\n"
" FLOAT8 in_11=vload8(0,input0+offset_11);\n"
" FLOAT8 in_12=vload8(0,input0+offset_12);\n"
" FLOAT8 in_21=vload8(0,input0+offset_21);\n"
" FLOAT8 in_22=vload8(0,input0+offset_22);\n"
" FLOAT8 in_cx=vload8(0,input1+pos.x*8+pos.y*width);\n"
" in_12=in_12+in_cx;\n"
" in_21=in_12+in_21;\n"
" in_12=in_22+in_12;\n"
" in_22=in_22+in_21;\n"
" in_12=in_11+in_12;\n"
" vstore8(in_21,0,output+offset_21);\n"
" vstore8(in_22,0,output+offset_22);\n"
" vstore8(in_12,0,output+offset_12);\n"
" }\n"
"}\n"
// binary_function_buf: C = OPERATOR(in0, in1) element-wise over 8 columns
// and VEC_H consecutive rows per work item (pos.y selects a group of VEC_H
// rows). baseOffsets.xyz and strides.xyz give the starting offset and row
// stride of A (input0), B (input1) and C (output). OPERATOR defaults to
// in0+in1 when the host does not define it.
// NOTE(review): the #if ladder below handles VEC_H of 1, 2 and 4 only. A value
// of 3 would process 2 rows while pos.y still advances by 3 rows.
"#ifndef OPERATOR\n"
"#define OPERATOR in0+in1\n"
"#endif\n"
"__kernel void binary_function_buf(__private int global_dim0,__private int global_dim1,\n"
" __global FLOAT* input0,__global FLOAT* input1,__global FLOAT* output,\n"
" __private const int4 baseOffsets,//[offsetA,offsetB,offsetC,0]\n"
" __private const int4 strides//[strideA,strideB,strideC,0]\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));// [X/16,Y]\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int baseOffsetA=baseOffsets.x;\n"
" const int baseOffsetB=baseOffsets.y;\n"
" const int baseOffsetC=baseOffsets.z;\n"
" const int strideA=strides.x;\n"
" const int strideB=strides.y;\n"
" const int strideC=strides.z;\n"
" \n"
" \n"
" int offsetA=pos.x*8+pos.y*VEC_H*strideA+baseOffsetA;\n"
" int offsetB=pos.x*8+pos.y*VEC_H*strideB+baseOffsetB;\n"
" int offsetC=pos.x*8+pos.y*VEC_H*strideC+baseOffsetC;\n"
" {\n"
" FLOAT8 in0=vload8(0,input0+offsetA);\n"
" FLOAT8 in1=vload8(0,input1+offsetB);\n"
" FLOAT8 out=OPERATOR;\n"
" vstore8(out,0,output+offsetC);\n"
" }\n"
" #if VEC_H >= 2\n"
" {\n"
" offsetA += strideA;\n"
" offsetB += strideB;\n"
" offsetC += strideC;\n"
" FLOAT8 in0=vload8(0,input0+offsetA);\n"
" FLOAT8 in1=vload8(0,input1+offsetB);\n"
" FLOAT8 out=OPERATOR;\n"
" vstore8(out,0,output+offsetC);\n"
" }\n"
" #endif\n"
" #if VEC_H == 4\n"
" {\n"
" offsetA += strideA;\n"
" offsetB += strideB;\n"
" offsetC += strideC;\n"
" FLOAT8 in0=vload8(0,input0+offsetA);\n"
" FLOAT8 in1=vload8(0,input1+offsetB);\n"
" FLOAT8 out=OPERATOR;\n"
" vstore8(out,0,output+offsetC);\n"
" }\n"
" {\n"
" offsetA += strideA;\n"
" offsetB += strideB;\n"
" offsetC += strideC;\n"
" FLOAT8 in0=vload8(0,input0+offsetA);\n"
" FLOAT8 in1=vload8(0,input1+offsetB);\n"
" FLOAT8 out=OPERATOR;\n"
" vstore8(out,0,output+offsetC);\n"
" }\n"
" #endif\n"
" }\n"
"}\n"
;
#endif
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* matmul_params_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"// =================================================================================================\n"
"#define USE_INLINE_KEYWORD 1\n"
"#ifndef MWG\n"
" #define MWG 8 // Tile-size in dimension M (e.g. 64,128)\n"
"#endif\n"
"#ifndef NWG\n"
" #define NWG 8 // Tile-size in dimension N (e.g. 64,128)\n"
"#endif\n"
"#ifndef KWG\n"
" #define KWG 16 // Tile-size in dimension K (e.g. 8,16)\n"
"#endif\n"
"#ifndef MDIMC\n"
" #define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8,16,32)\n"
"#endif\n"
"#ifndef NDIMC\n"
" #define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8,16,32)\n"
"#endif\n"
"#ifndef MDIMA\n"
" #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA*MDIMA (kernel 0 only)\n"
"#endif\n"
"#ifndef NDIMB\n"
" #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB*NDIMB (kernel 0 only)\n"
"#endif\n"
"#ifndef KWI\n"
" #define KWI 2 // Unroll factor of the KWG loop (smaller or equal than KWG)\n"
"#endif\n"
"#ifndef VWM\n"
" #define VWM 1 // Vector width of matrices A and C\n"
"#endif\n"
"#ifndef VWN\n"
" #define VWN 1 // Vector width of matrix B\n"
"#endif\n"
"#ifndef STRM\n"
" #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"#ifndef STRN\n"
" #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"#ifndef SA\n"
" #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"#ifndef SB\n"
" #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"// Helper parameters based on the above tuning parameters\n"
"#define MWI (MWG/MDIMC) // Work per work-item (M-dimension)\n"
"#define NWI (NWG/NDIMC) // Work per work-item (N-dimension)\n"
"#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA*MDIMA\n"
"#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB*NDIMB\n"
"#define MWA (MWG/MDIMA) // Amount of loads-per-thread for matrix A (M-dimension)\n"
"#define KWA (KWG/KDIMA) // Amount of loads-per-thread for matrix A (K-dimension)\n"
"#define KWB (KWG/KDIMB) // Amount of loads-per-thread for matrix B (K-dimension)\n"
"#define NWB (NWG/NDIMB) // Amount of loads-per-thread for matrix B (N-dimension)\n"
"// Settings\n"
"#ifndef USE_VECTOR_MAD\n"
" #define USE_VECTOR_MAD 0 // Unroll (0) or don't (1) unroll the vector MAD manually\n"
"#endif\n"
"#ifndef GLOBAL_MEM_FENCE\n"
" #define GLOBAL_MEM_FENCE 0 // Global synchronisation barrier for potential better performance\n"
"#endif\n"
"// Pointers to local memory objects (using a define because CUDA doesn't need them)\n"
"#ifndef LOCAL_PTR\n"
" #define LOCAL_PTR __local\n"
"#endif\n"
"// Don't use the non-IEEE754 compliant OpenCL built-in mad() instruction per default. For specific\n"
"// devices,this is enabled (see src/routine.cpp).\n"
"#ifndef USE_CL_MAD\n"
" #define USE_CL_MAD 0\n"
"#endif\n"
"// BIAS_TYPE\n"
"// 0 -> without bias\n"
"// 1 -> with bias (add) [N]\n"
"// 2 -> with bias (eltwise_add) [M,N]\n"
"// 3 -> with bias (eltwise_sub) [M,N]\n"
"// 4 -> with bias (eltwise_sub and get negative) [M,N]\n"
"// 5 -> with bias (mask 0 for invalid) [M,N]\n"
"#ifndef BIAS_TYPE\n"
" #define BIAS_TYPE 0\n"
"#endif\n"
"#if BIAS_TYPE == 1\n"
"#define DEAL_BIAS(x,a) x=x+a\n"
"#elif BIAS_TYPE == 2\n"
"#define DEAL_BIAS(x,a) x=x+a\n"
"#elif BIAS_TYPE == 3\n"
"#define DEAL_BIAS(x,a) x=x-a\n"
"#elif BIAS_TYPE == 4\n"
"#define DEAL_BIAS(x,a) x=a-x\n"
"#elif BIAS_TYPE == 5\n"
"#define DEAL_BIAS(x,a) x=(a == 0 ? (FLOAT)(-FLT_MAX) : x)\n"
"#endif\n"
"// By default the workgroup size requirement is enabled. For Qualcomm devices the workgroup size\n"
"// requirement results in worse performance and is disabled (src/utilities/compile.cpp)\n"
"#ifndef RELAX_WORKGROUP_SIZE\n"
" #define RELAX_WORKGROUP_SIZE 0\n"
"#endif\n"
"typedef float real_arg;\n"
"#define GetRealArg(x) (FLOAT)x\n"
"typedef FLOAT real;\n"
"#ifndef PRECISION_COMPUTE\n"
"#define PRECISION_COMPUTE COMPUTE_FLOAT\n"
"#define CONVERT_PRECISION_COMPUTE(x) CONVERT_COMPUTE_FLOAT(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE2\n"
"#define PRECISION_COMPUTE2 COMPUTE_FLOAT2\n"
"#define CONVERT_PRECISION_COMPUTE2(x) CONVERT_COMPUTE_FLOAT2(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE4\n"
"#define PRECISION_COMPUTE4 COMPUTE_FLOAT4\n"
"#define CONVERT_PRECISION_COMPUTE4(x) CONVERT_COMPUTE_FLOAT4(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE8\n"
"#define PRECISION_COMPUTE8 COMPUTE_FLOAT8\n"
"#define CONVERT_PRECISION_COMPUTE8(x) CONVERT_COMPUTE_FLOAT8(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE16\n"
"#define PRECISION_COMPUTE16 COMPUTE_FLOAT16\n"
"#define CONVERT_PRECISION_COMPUTE16(x) CONVERT_COMPUTE_FLOAT16(x)\n"
"#endif\n"
"#define ZERO (PRECISION_COMPUTE)0.0f\n"
"// Sets a variable to zero\n"
"#define SetToZero(a) a=ZERO\n"
"#define IsZero(a) (a == ZERO)\n"
"#define Multiply(c,a,b) c=a*b\n"
"#if USE_CL_MAD == 1\n"
"#define MultiplyAdd(c,a,b) c=mad(a,b,c)\n"
"#else\n"
"#define MultiplyAdd(c,a,b) c += a*b\n"
"#endif\n"
"#define AXPBY(e,a,b,c,d) e=a*b+c*d\n"
"// Force inlining functions or not: some compilers don't support the inline keyword\n"
"#ifdef USE_INLINE_KEYWORD\n"
" #define INLINE_FUNC inline\n"
"#else\n"
" #define INLINE_FUNC\n"
"#endif\n"
"INLINE_FUNC int GetGroupID1() { return get_group_id(1); }\n"
"INLINE_FUNC int GetGroupID0() { return get_group_id(0); }\n"
"// =================================================================================================\n"
"// Data-widths in dimension M\n"
"#if VWM == 1\n"
" typedef FLOAT realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT(x)\n"
"#elif VWM == 2\n"
" typedef FLOAT2 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE2\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE2(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT2(x)\n"
"#elif VWM == 4\n"
" typedef FLOAT4 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE4\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE4(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT4(x)\n"
"#elif VWM == 8\n"
" typedef FLOAT8 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE8\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE8(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT8(x)\n"
"#elif VWM == 16\n"
" typedef FLOAT16 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE16\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE16(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT16(x)\n"
"#endif\n"
"// Data-widths in dimension N\n"
"#if VWN == 1\n"
" typedef FLOAT realN;\n"
" typedef int intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT(x)\n"
"#elif VWN == 2\n"
" typedef FLOAT2 realN;\n"
" typedef int2 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE2\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE2(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT2(x)\n"
"#elif VWN == 4\n"
" typedef FLOAT4 realN;\n"
" typedef int4 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE4\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE4(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT4(x)\n"
"#elif VWN == 8\n"
" typedef FLOAT8 realN;\n"
" typedef int8 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE8\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE8(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT8(x)\n"
"#elif VWN == 16\n"
" typedef FLOAT16 realN;\n"
" typedef int16 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE16\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE16(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT16(x)\n"
"#endif\n"
"// =================================================================================================\n"
"// Initializes the accumulation registers to zero\n"
"INLINE_FUNC COMPUTE_FLOATM InitAccRegisters() {\n"
" COMPUTE_FLOATM result;\n"
" #if VWM == 1\n"
" SetToZero(result);\n"
" #elif VWM == 2\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" #elif VWM == 4\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" SetToZero(result.z);\n"
" SetToZero(result.w);\n"
" #elif VWM == 8\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" #elif VWM == 16\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" SetToZero(result.s8);\n"
" SetToZero(result.s9);\n"
" SetToZero(result.sA);\n"
" SetToZero(result.sB);\n"
" SetToZero(result.sC);\n"
" SetToZero(result.sD);\n"
" SetToZero(result.sE);\n"
" SetToZero(result.sF);\n"
" #endif\n"
" return result;\n"
"}\n"
"INLINE_FUNC COMPUTE_FLOATN InitAccRegistersN() {\n"
" COMPUTE_FLOATN result;\n"
" #if VWN == 1\n"
" SetToZero(result);\n"
" #elif VWN == 2\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" #elif VWN == 4\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" SetToZero(result.z);\n"
" SetToZero(result.w);\n"
" #elif VWN == 8\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" #elif VWN == 16\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" SetToZero(result.s8);\n"
" SetToZero(result.s9);\n"
" SetToZero(result.sA);\n"
" SetToZero(result.sB);\n"
" SetToZero(result.sC);\n"
" SetToZero(result.sD);\n"
" SetToZero(result.sE);\n"
" SetToZero(result.sF);\n"
" #endif\n"
" return result;\n"
"}\n"
// ---- Embedded OpenCL: GlobalToLocalA (compiled only when SA == 1) ----
// Cooperatively copies one kwg-offset tile of matrix A from global memory
// (agm) into the work-group's local buffer (alm), in realM (VWM-wide) units.
// The flat thread id is split as la0 = tid % MDIMA and la1 = tid / MDIMA.
// Each work-item then copies (MWA/VWM) x KWA elements:
//   read : agm[(kg + kwg) * (kSizeM/VWM) + mg + GetGroupID0() * (MWG/VWM)]
//   write: alm[kg * (MWG/VWM) + mg],  with kg = _kia + la1 * KWA
// STRM == 0 gives each work-item a contiguous run of mg values.
// STRM == 1 interleaves mg with a stride of MDIMA.
// No barrier is issued here. The caller (XgemmBody) issues
// barrier(CLK_LOCAL_MEM_FENCE) after the local-memory loads.
// NOTE(review): kSizeM/VWM is integer division, so kSizeM is presumably a
// multiple of VWM. Confirm on the host side.
"// =================================================================================================\n"
"// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for\n"
"// caching the A input matrix.\n"
"#if SA == 1\n"
"INLINE_FUNC void GlobalToLocalA(const __global realM* restrict agm,LOCAL_PTR realM* alm,\n"
" const int kSizeM,const int tid,const int kwg) {\n"
" const int la0=tid % MDIMA;\n"
" const int la1=tid/MDIMA;\n"
" #pragma unroll\n"
" for (int _mia=0; _mia<MWA/VWM; _mia += 1) {\n"
" #pragma unroll\n"
" for (int _kia=0; _kia<KWA; _kia += 1) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" int mg=_mia+la0*(MWA/VWM);\n"
" #elif STRM == 1\n"
" int mg=la0+_mia*MDIMA;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int kg=_kia+la1*KWA;\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" int idk=kg+kwg;\n"
" // Loads the data from global memory (not transposed) into the local memory\n"
" alm[kg*(MWG/VWM)+mg]=agm[idk*(kSizeM/VWM)+idm];\n"
" }\n"
" }\n"
"}\n"
"#endif\n"
// ---- Embedded OpenCL: GlobalToLocalB (compiled only when SB == 1) ----
// B-matrix counterpart of GlobalToLocalA, in realN (VWN-wide) units.
// The flat thread id is split as lb0 = tid % NDIMB and lb1 = tid / NDIMB.
// Each work-item then copies KWB x (NWB/VWN) elements:
//   read : bgm[(kg + kwg) * (kSizeN/VWN) + ng + GetGroupID1() * (NWG/VWN)]
//   write: blm[kg * (NWG/VWN) + ng],  with kg = _kib + lb1 * KWB
// STRN == 0 gives each work-item a contiguous run of ng values.
// STRN == 1 interleaves ng with a stride of NDIMB.
// No barrier is issued here. The caller (XgemmBody) synchronizes afterwards.
// NOTE(review): kSizeN is presumably a multiple of VWN. Confirm on the host side.
"// Same as above,but now for the B input matrix\n"
"#if SB == 1\n"
"INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm,LOCAL_PTR realN* blm,\n"
" const int kSizeN,const int tid,const int kwg) {\n"
" const int lb0=tid % NDIMB;\n"
" const int lb1=tid/NDIMB;\n"
" #pragma unroll\n"
" for (int _kib=0; _kib<KWB; _kib += 1) {\n"
" #pragma unroll\n"
" for (int _nib=0; _nib<NWB/VWN; _nib += 1) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int ng=_nib+lb0*(NWB/VWN);\n"
" #elif STRN == 1\n"
" int ng=lb0+_nib*NDIMB;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int kg=_kib+lb1*KWB;\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" int idk=kg+kwg;\n"
" // Loads the data from global memory (transposed) into the local memory\n"
" blm[kg*(NWG/VWN)+ng]=bgm[idk*(kSizeN/VWN)+idn];\n"
" }\n"
" }\n"
"}\n"
"#endif\n"
// ---- Embedded OpenCL: GlobalIndexA (SA == 0 path: A is read directly from global) ----
// The "#if SA == 0" opened here also guards GlobalToPrivateOptA and
// GlobalToPrivateA, and is closed after GlobalToPrivateA.
// Returns this work-item's base M-index into A, in realM (VWM-wide) units:
//   STRM == 0: get_local_id(0) * (MWI/VWM) + GetGroupID0() * (MWG/VWM)
//   STRM == 1: get_local_id(0)             + GetGroupID0() * (MWG/VWM)
// XgemmBody computes this once and passes it as `base` to GlobalToPrivateOptA.
// NOTE(review): STRM must be 0 or 1; any other value leaves mg undeclared.
"// =================================================================================================\n"
"// Caches global off-chip memory directly into per-thread private memory (registers). This function\n"
"// is specific for caching the A input matrix.\n"
"#if SA == 0\n"
"INLINE_FUNC int GlobalIndexA() {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" // [MWG/MWI,MWI/VWM,VWM]\n"
" int mg=get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" // [MWI/VWM,MWG/MWI,VWM]\n"
" int mg=get_local_id(0);\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" // [kSizeM/MWG,(MWG/VWM),VWM]\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" return idm;\n"
"}\n"
// ---- Embedded OpenCL: GlobalToPrivateOptA ----
// Loads the _mi-th realM of A at k-index idk directly from global memory,
// starting from a precomputed base index:
//   idm = base + _mi          (STRM == 0)
//   idm = base + _mi * MDIMC  (STRM == 1)
//   returns agm[idk * (astride/VWM) + idm]
// When base == GlobalIndexA() and astride == kSizeM, this reads the same element
// as GlobalToPrivateA, but the get_local_id/GetGroupID0 arithmetic is done
// once by the caller. XgemmBody passes stride.s0 as astride.
"INLINE_FUNC realM GlobalToPrivateOptA(const __global realM* restrict agm,const int base,const int _mi,\n"
" const int astride/*kSizeM*/,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" // [MWG/MWI,MWI/VWM,VWM]\n"
" int idm=base+_mi;\n"
" #elif STRM == 1\n"
" // [MWI/VWM,MWG/MWI,VWM]\n"
" int idm=base+_mi*MDIMC;\n"
" #endif\n"
" // Loads the data from global memory (not transposed) and stores into registers\n"
" // [kSizeK,kSizeM/VWM,VWM]\n"
" return agm[idk*(astride/VWM)+idm];\n"
"}\n"
// ---- Embedded OpenCL: GlobalToPrivateA ----
// Loads the _mi-th realM of A at k-index idk directly from global memory,
// recomputing this work-item's offset on every call:
//   STRM == 0: mg = _mi + get_local_id(0) * (MWI/VWM)
//   STRM == 1: mg = get_local_id(0) + _mi * MDIMC
//   idm = mg + GetGroupID0() * (MWG/VWM)
//   returns agm[idk * (kSizeM/VWM) + idm]
// The trailing "#endif" closes the "#if SA == 0" guard that precedes GlobalIndexA.
"INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm,const int _mi,\n"
" const int kSizeM,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" // [MWG/MWI,MWI/VWM,VWM]\n"
" int mg=_mi+get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" // [MWI/VWM,MWG/MWI,VWM]\n"
" int mg=get_local_id(0)+_mi*MDIMC;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" // [kSizeM/MWG,(MWG/VWM),VWM]\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" // Loads the data from global memory (not transposed) and stores into registers\n"
" // [kSizeK,kSizeM/VWM,VWM]\n"
" return agm[idk*(kSizeM/VWM)+idm];\n"
"}\n"
"#endif\n"
// ---- Embedded OpenCL: GlobalIndexB (SB == 0 path: B is read directly from global) ----
// The "#if SB == 0" opened here also guards GlobalToPrivateOptB and
// GlobalToPrivateB, and is closed after GlobalToPrivateB.
// Returns this work-item's base N-index into B, in realN (VWN-wide) units:
//   STRN == 0: get_local_id(1) * (NWI/VWN) + GetGroupID1() * (NWG/VWN)
//   STRN == 1: get_local_id(1)             + GetGroupID1() * (NWG/VWN)
// XgemmBody computes this once and passes it as `base` to GlobalToPrivateOptB.
// NOTE(review): STRN must be 0 or 1; any other value leaves ng undeclared.
"// Same as above,but now for the B input matrix\n"
"#if SB == 0\n"
"INLINE_FUNC int GlobalIndexB() {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int ng=get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1);\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" return idn;\n"
"}\n"
// ---- Embedded OpenCL: GlobalToPrivateOptB ----
// Loads the _ni-th realN of B at k-index idk directly from global memory,
// starting from a precomputed base index:
//   idn = base + _ni          (STRN == 0)
//   idn = base + _ni * NDIMC  (STRN == 1)
//   returns bgm[idk * (bstride/VWN) + idn]
// When base == GlobalIndexB() and bstride == kSizeN, this matches
// GlobalToPrivateB. XgemmBody passes stride.s1 as bstride.
"INLINE_FUNC realN GlobalToPrivateOptB(const __global realN* restrict bgm,const int base,const int _ni,\n"
" const int bstride/*kSizeN*/,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int idn=base+_ni;\n"
" #elif STRN == 1\n"
" int idn=base+_ni*NDIMC;\n"
" #endif\n"
" // Loads the data from global memory (transposed) and stores into registers\n"
" return bgm[idk*(bstride/VWN)+idn];\n"
"}\n"
// ---- Embedded OpenCL: GlobalToPrivateB ----
// Loads the _ni-th realN of B at k-index idk directly from global memory,
// recomputing this work-item's offset on every call:
//   STRN == 0: ng = _ni + get_local_id(1) * (NWI/VWN)
//   STRN == 1: ng = get_local_id(1) + _ni * NDIMC
//   idn = ng + GetGroupID1() * (NWG/VWN)
//   returns bgm[idk * (kSizeN/VWN) + idn]
// The trailing "#endif" closes the "#if SB == 0" guard that precedes GlobalIndexB.
"INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm,const int _ni,\n"
" const int kSizeN,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int ng=_ni+get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1)+_ni*NDIMC;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" // Loads the data from global memory (transposed) and stores into registers\n"
" return bgm[idk*(kSizeN/VWN)+idn];\n"
"}\n"
"#endif\n"
// ---- Embedded OpenCL: LocalToPrivateA (compiled only when SA == 1) ----
// Reads the _mi-th realM of the local A tile at local k-row kg:
//   STRM == 0: mg = _mi + get_local_id(0) * (MWI/VWM)
//   STRM == 1: mg = get_local_id(0) + _mi * MDIMC
//   returns alm[kg * (MWG/VWM) + mg]
// The row pitch (MWG/VWM) matches the alm index written by GlobalToLocalA.
"// =================================================================================================\n"
"// Caches on-chip local memory into per-thread private memory (registers). This function is specific\n"
"// for caching the A input matrix.\n"
"#if SA == 1\n"
"INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR realM* alm,const int _mi,const int kg) {\n"
" #if STRM == 0\n"
" int mg=_mi+get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" int mg=get_local_id(0)+_mi*MDIMC;\n"
" #endif\n"
" return alm[kg*(MWG/VWM)+mg];\n"
"}\n"
"#endif\n"
// ---- Embedded OpenCL: LocalToPrivateB (compiled only when SB == 1) ----
// Reads the _ni-th realN of the local B tile at local k-row kg:
//   STRN == 0: ng = _ni + get_local_id(1) * (NWI/VWN)
//   STRN == 1: ng = get_local_id(1) + _ni * NDIMC
//   returns blm[kg * (NWG/VWN) + ng]
// The row pitch (NWG/VWN) matches the blm index written by GlobalToLocalB.
"// Same as above,but now for the B input matrix\n"
"#if SB == 1\n"
"INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm,const int _ni,const int kg) {\n"
" #if STRN == 0\n"
" int ng=_ni+get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1)+_ni*NDIMC;\n"
" #endif\n"
" return blm[kg*(NWG/VWN)+ng];\n"
"}\n"
"#endif\n"
// ---- Embedded OpenCL: MultiplyAddVector ----
// Vector x scalar accumulate: returns cvec + avec * bval, with the scalar bval
// applied to every lane of avec.
//   USE_VECTOR_MAD == 1, USE_CL_MAD == 1: built-in mad() on whole vectors
//     (bval is broadcast by the (COMPUTE_FLOATM) cast)
//   USE_VECTOR_MAD == 1, otherwise:       vector `+=` with `avec*bval`
//   USE_VECTOR_MAD != 1:                  per-lane MultiplyAdd macro (defined
//     elsewhere; presumably c += a * b), with lanes selected by VWM
// NOTE(review): the OpenCL spec allows mad() to trade accuracy for speed, so
// the USE_CL_MAD path may differ numerically from the other two.
"// The vectorised multiply-add function\n"
"INLINE_FUNC COMPUTE_FLOATM MultiplyAddVector(COMPUTE_FLOATM cvec,COMPUTE_FLOATM avec,PRECISION_COMPUTE bval) {\n"
" #if USE_VECTOR_MAD == 1\n"
" #if USE_CL_MAD == 1\n"
" cvec=mad(avec,(COMPUTE_FLOATM)bval,cvec);\n"
" #else\n"
" cvec += avec*bval;\n"
" #endif\n"
" #else\n"
" #if VWM == 1\n"
" MultiplyAdd(cvec,avec,bval);\n"
" #elif VWM == 2\n"
" MultiplyAdd(cvec.x ,avec.x,bval);\n"
" MultiplyAdd(cvec.y ,avec.y,bval);\n"
" #elif VWM == 4\n"
" MultiplyAdd(cvec.x ,avec.x,bval);\n"
" MultiplyAdd(cvec.y ,avec.y,bval);\n"
" MultiplyAdd(cvec.z ,avec.z,bval);\n"
" MultiplyAdd(cvec.w ,avec.w,bval);\n"
" #elif VWM == 8\n"
" MultiplyAdd(cvec.s0,avec.s0,bval);\n"
" MultiplyAdd(cvec.s1,avec.s1,bval);\n"
" MultiplyAdd(cvec.s2,avec.s2,bval);\n"
" MultiplyAdd(cvec.s3,avec.s3,bval);\n"
" MultiplyAdd(cvec.s4,avec.s4,bval);\n"
" MultiplyAdd(cvec.s5,avec.s5,bval);\n"
" MultiplyAdd(cvec.s6,avec.s6,bval);\n"
" MultiplyAdd(cvec.s7,avec.s7,bval);\n"
" #elif VWM == 16\n"
" MultiplyAdd(cvec.s0,avec.s0,bval);\n"
" MultiplyAdd(cvec.s1,avec.s1,bval);\n"
" MultiplyAdd(cvec.s2,avec.s2,bval);\n"
" MultiplyAdd(cvec.s3,avec.s3,bval);\n"
" MultiplyAdd(cvec.s4,avec.s4,bval);\n"
" MultiplyAdd(cvec.s5,avec.s5,bval);\n"
" MultiplyAdd(cvec.s6,avec.s6,bval);\n"
" MultiplyAdd(cvec.s7,avec.s7,bval);\n"
" MultiplyAdd(cvec.s8,avec.s8,bval);\n"
" MultiplyAdd(cvec.s9,avec.s9,bval);\n"
" MultiplyAdd(cvec.sA,avec.sA,bval);\n"
" MultiplyAdd(cvec.sB,avec.sB,bval);\n"
" MultiplyAdd(cvec.sC,avec.sC,bval);\n"
" MultiplyAdd(cvec.sD,avec.sD,bval);\n"
" MultiplyAdd(cvec.sE,avec.sE,bval);\n"
" MultiplyAdd(cvec.sF,avec.sF,bval);\n"
" #endif\n"
" #endif\n"
" return cvec;\n"
"}\n"
// ---- Embedded OpenCL: MultiplyAddVectorN ----
// Mirror of MultiplyAddVector with the roles swapped: here avec is the scalar
// and bval is the COMPUTE_FLOATN vector, so it returns cvec + avec * bval with
// avec applied to every lane of bval.
//   USE_VECTOR_MAD == 1, USE_CL_MAD == 1: built-in mad(), with avec broadcast by
//     the (COMPUTE_FLOATN) cast
//   USE_VECTOR_MAD == 1, otherwise:       vector `+=` with `avec*bval`
//   USE_VECTOR_MAD != 1:                  per-lane MultiplyAdd macro, with lanes
//     selected by VWN
// XgemmBody uses this on the OUTPUTMN path.
"// The vectorised multiply-add function\n"
"INLINE_FUNC COMPUTE_FLOATN MultiplyAddVectorN(COMPUTE_FLOATN cvec,PRECISION_COMPUTE avec,COMPUTE_FLOATN bval) {\n"
" #if USE_VECTOR_MAD == 1\n"
" #if USE_CL_MAD == 1\n"
" cvec=mad((COMPUTE_FLOATN)avec,bval,cvec);\n"
" #else\n"
" cvec += avec*bval;\n"
" #endif\n"
" #else\n"
" #if VWN == 1\n"
" MultiplyAdd(cvec,avec,bval);\n"
" #elif VWN == 2\n"
" MultiplyAdd(cvec.x ,avec,bval.x);\n"
" MultiplyAdd(cvec.y ,avec,bval.y);\n"
" #elif VWN == 4\n"
" MultiplyAdd(cvec.x ,avec,bval.x);\n"
" MultiplyAdd(cvec.y ,avec,bval.y);\n"
" MultiplyAdd(cvec.z ,avec,bval.z);\n"
" MultiplyAdd(cvec.w ,avec,bval.w);\n"
" #elif VWN == 8\n"
" MultiplyAdd(cvec.s0,avec,bval.s0);\n"
" MultiplyAdd(cvec.s1,avec,bval.s1);\n"
" MultiplyAdd(cvec.s2,avec,bval.s2);\n"
" MultiplyAdd(cvec.s3,avec,bval.s3);\n"
" MultiplyAdd(cvec.s4,avec,bval.s4);\n"
" MultiplyAdd(cvec.s5,avec,bval.s5);\n"
" MultiplyAdd(cvec.s6,avec,bval.s6);\n"
" MultiplyAdd(cvec.s7,avec,bval.s7);\n"
" #elif VWN == 16\n"
" MultiplyAdd(cvec.s0,avec,bval.s0);\n"
" MultiplyAdd(cvec.s1,avec,bval.s1);\n"
" MultiplyAdd(cvec.s2,avec,bval.s2);\n"
" MultiplyAdd(cvec.s3,avec,bval.s3);\n"
" MultiplyAdd(cvec.s4,avec,bval.s4);\n"
" MultiplyAdd(cvec.s5,avec,bval.s5);\n"
" MultiplyAdd(cvec.s6,avec,bval.s6);\n"
" MultiplyAdd(cvec.s7,avec,bval.s7);\n"
" MultiplyAdd(cvec.s8,avec,bval.s8);\n"
" MultiplyAdd(cvec.s9,avec,bval.s9);\n"
" MultiplyAdd(cvec.sA,avec,bval.sA);\n"
" MultiplyAdd(cvec.sB,avec,bval.sB);\n"
" MultiplyAdd(cvec.sC,avec,bval.sC);\n"
" MultiplyAdd(cvec.sD,avec,bval.sD);\n"
" MultiplyAdd(cvec.sE,avec,bval.sE);\n"
" MultiplyAdd(cvec.sF,avec,bval.sF);\n"
" #endif\n"
" #endif\n"
" return cvec;\n"
"}\n"
// ---- Embedded OpenCL: INT2 ----
// A plain two-int struct that carries a store base offset.
// StoreIndexM/StoreIndexN set index[0] to the M offset and index[1] to the
// N offset; StoreResultsM/StoreResultsN read them as baseOffset.
// NOTE(review): a user struct is used instead of the built-in int2 vector
// type; the reason is not visible in this chunk.
"// =================================================================================================\n"
"// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication\n"
"// with the constants: Cgm=alpha*A*B+beta*Cgm=alpha*Cpm+beta*Cgm\n"
"typedef struct {\n"
" int index[2];\n"
"} INT2;\n"
// ---- Embedded OpenCL: StoreIndexM ----
// Computes this work-item's base store offsets for StoreResultsM's [N,M] output:
//   index[0] = idm (realM units) = mg + GetGroupID0() * (MWG/VWM), where
//              mg = get_local_id(0) * (MWI/VWM) (STRM == 0) or get_local_id(0) (STRM == 1)
//   index[1] = idn (scalar units) = ng + GetGroupID1() * NWG, where
//              ng = get_local_id(1) * NWI (STRN == 0) or get_local_id(1) * VWN (STRN == 1)
"INLINE_FUNC INT2 StoreIndexM() {\n"
" INT2 res;\n"
" #if STRM == 0\n"
" int mg=get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" int mg=get_local_id(0);\n"
" #endif\n"
" #if STRN == 0\n"
" int ng=get_local_id(1)*NWI;\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1)*VWN;\n"
" #endif\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" int idn=ng+GetGroupID1()*NWG;\n"
" res.index[0]=idm;\n"
" res.index[1]=idn;\n"
" return res;\n"
"}\n"
// ---- Embedded OpenCL: StoreResultsM ----
// Writes one realM element of C (layout [N,M]) for accumulator slot (_mi,_ni):
//   idm = _mi + baseOffset.index[0]                                   (STRM == 0)
//       = baseOffset.index[0] + _mi * MDIMC                           (STRM == 1)
//   idn = _ni + baseOffset.index[1]                                   (STRN == 0)
//       = _ni % VWN + baseOffset.index[1] + (_ni / VWN) * VWN * NDIMC (STRN == 1)
//   index = idn * (kSizeM/VWM) + idm
// The epilogue is chosen at kernel-compile time:
//   ONLY_HAVE_ALPHA: per-lane Multiply macro (presumably result = alpha * c_value)
//   HAVE_ALPHA_BETA: first reads the existing cgm[index], then applies the
//                    per-lane AXPBY macro (presumably alpha * c_value + beta * C)
//   neither defined: c_value is stored unscaled
// NOTE(review): both branches declare `xval`, so defining ONLY_HAVE_ALPHA and
// HAVE_ALPHA_BETA together would fail to compile. They are presumably
// mutually exclusive on the host side.
"// layout : [N,M]\n"
"INLINE_FUNC void StoreResultsM(__global realM* cgm,COMPUTE_FLOATM c_value,const INT2 baseOffset,const int _mi,const int _ni,\n"
" const int kSizeM,const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n"
" #if STRM == 0\n"
" int idm=_mi+baseOffset.index[0];\n"
" #elif STRM == 1\n"
" int idm=baseOffset.index[0]+_mi*MDIMC;\n"
" #endif\n"
" #if STRN == 0\n"
" int idn=_ni+baseOffset.index[1];\n"
" #elif STRN == 1\n"
" int idn=_ni%VWN+baseOffset.index[1]+(_ni/VWN)*VWN*NDIMC;\n"
" #endif\n"
" \n"
" int index=idn*(kSizeM/VWM)+idm;\n"
" COMPUTE_FLOATM result=c_value;\n"
" // The final multiplication with alpha (in case beta == 0)\n"
" #ifdef ONLY_HAVE_ALPHA\n"
" COMPUTE_FLOATM xval=c_value;\n"
" #if VWM == 1\n"
" Multiply(result,alpha,xval);\n"
" #elif VWM == 2\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" #elif VWM == 4\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" Multiply(result.z,alpha,xval.z);\n"
" Multiply(result.w,alpha,xval.w);\n"
" #elif VWM == 8\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" #elif VWM == 16\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" Multiply(result.s8,alpha,xval.s8);\n"
" Multiply(result.s9,alpha,xval.s9);\n"
" Multiply(result.sA,alpha,xval.sA);\n"
" Multiply(result.sB,alpha,xval.sB);\n"
" Multiply(result.sC,alpha,xval.sC);\n"
" Multiply(result.sD,alpha,xval.sD);\n"
" Multiply(result.sE,alpha,xval.sE);\n"
" Multiply(result.sF,alpha,xval.sF);\n"
" #endif\n"
" #endif\n"
" // The final multiplication with alpha and the addition with beta*C\n"
" #ifdef HAVE_ALPHA_BETA\n"
" COMPUTE_FLOATM xval=c_value;\n"
" COMPUTE_FLOATM yval=CONVERT_COMPUTE_FLOATM(cgm[index]);\n"
" #if VWM == 1\n"
" AXPBY(result,alpha,xval,beta,yval);\n"
" #elif VWM == 2\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" #elif VWM == 4\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" AXPBY(result.z,alpha,xval.z,beta,yval.z);\n"
" AXPBY(result.w,alpha,xval.w,beta,yval.w);\n"
" #elif VWM == 8\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" #elif VWM == 16\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" AXPBY(result.s8,alpha,xval.s8,beta,yval.s8);\n"
" AXPBY(result.s9,alpha,xval.s9,beta,yval.s9);\n"
" AXPBY(result.sA,alpha,xval.sA,beta,yval.sA);\n"
" AXPBY(result.sB,alpha,xval.sB,beta,yval.sB);\n"
" AXPBY(result.sC,alpha,xval.sC,beta,yval.sC);\n"
" AXPBY(result.sD,alpha,xval.sD,beta,yval.sD);\n"
" AXPBY(result.sE,alpha,xval.sE,beta,yval.sE);\n"
" AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n"
" #endif\n"
" #endif\n"
" cgm[index]=CONVERT_FLOATM(result);\n"
"}\n"
// ---- Embedded OpenCL: StoreIndexN ----
// Computes this work-item's base store offsets for StoreResultsN's [M,N] output:
//   index[0] = idm (scalar units) = mg + GetGroupID0() * MWG, where
//              mg = get_local_id(0) * MWI (STRM == 0) or get_local_id(0) * VWM (STRM == 1)
//   index[1] = idn (realN units) = ng + GetGroupID1() * (NWG/VWN), where
//              ng = get_local_id(1) * (NWI/VWN) (STRN == 0) or get_local_id(1) (STRN == 1)
"INLINE_FUNC INT2 StoreIndexN() {\n"
" INT2 res;\n"
" #if STRM == 0\n"
" int mg=get_local_id(0)*MWI;\n"
" #elif STRM == 1\n"
" int mg=get_local_id(0)*VWM;\n"
" #endif\n"
" #if STRN == 0\n"
" int ng=get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1);\n"
" #endif\n"
" int idm=mg+GetGroupID0()*MWG;\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" \n"
" res.index[0]=idm;\n"
" res.index[1]=idn;\n"
" return res;\n"
"}\n"
// ---- Embedded OpenCL: StoreResultsN ----
// Writes one realN element of C (layout [M,N]) for accumulator slot (_mi,_ni).
// It optionally adds a bias and applies RELU/RELU6:
//   idm = _mi + baseOffset.index[0]                                   (STRM == 0)
//       = _mi % VWM + baseOffset.index[0] + (_mi / VWM) * VWM * MDIMC (STRM == 1)
//   idn = _ni + baseOffset.index[1]                                   (STRN == 0)
//       = baseOffset.index[1] + _ni * NDIMC                           (STRN == 1)
//   index = idm * (cstride/VWN) + idn
// The alpha/beta epilogue is chosen at kernel-compile time, with lanes selected by VWN:
//   ONLY_HAVE_ALPHA: per-lane Multiply macro (presumably result = alpha * c_value)
//   HAVE_ALPHA_BETA: first reads the existing cgn[index], then applies the
//                    per-lane AXPBY macro (presumably alpha * c_value + beta * C)
//   neither defined: c_value is used unscaled
// NOTE(review): both alpha branches declare `xval`, so the two macros are
// presumably mutually exclusive.
"// layout : [M,N]\n"
"INLINE_FUNC void StoreResultsN(__global realN* cgn,COMPUTE_FLOATN c_value,\n"
" const INT2 baseOffset,\n"
// The parameter list depends on BIAS_TYPE:
//   BIAS_TYPE > 1:  a __global bias buffer egm
//   BIAS_TYPE == 1: a pointer epm with no address-space qualifier
//                   (presumably work-item private storage; confirm)
//   otherwise:      no bias parameter
" #if BIAS_TYPE>0\n"
" #if BIAS_TYPE>1\n"
" __global realN* egm,\n"
" #else\n"
" realN* epm,\n"
" #endif\n"
" #endif\n"
" const int _mi,const int _ni,\n"
" const int cstride/*kSizeN*/,const int dstride/*kSizeN*/,const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n"
" #if STRM == 0\n"
" int idm=_mi+baseOffset.index[0];\n"
" #elif STRM == 1\n"
" int idm=_mi%VWM+baseOffset.index[0]+(_mi/VWM)*VWM*MDIMC;\n"
" #endif\n"
" #if STRN == 0\n"
" int idn=_ni+baseOffset.index[1];\n"
" #elif STRN == 1\n"
" int idn=baseOffset.index[1]+_ni*NDIMC;\n"
" #endif\n"
" int index=idm*(cstride/VWN)+idn;\n"
" \n"
" COMPUTE_FLOATN result=c_value;\n"
" \n"
" // The final multiplication with alpha (in case beta == 0)\n"
" #ifdef ONLY_HAVE_ALPHA\n"
" COMPUTE_FLOATN xval=c_value;\n"
" #if VWN == 1\n"
" Multiply(result,alpha,xval);\n"
" #elif VWN == 2\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" #elif VWN == 4\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" Multiply(result.z,alpha,xval.z);\n"
" Multiply(result.w,alpha,xval.w);\n"
" #elif VWN == 8\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" #elif VWN == 16\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" Multiply(result.s8,alpha,xval.s8);\n"
" Multiply(result.s9,alpha,xval.s9);\n"
" Multiply(result.sA,alpha,xval.sA);\n"
" Multiply(result.sB,alpha,xval.sB);\n"
" Multiply(result.sC,alpha,xval.sC);\n"
" Multiply(result.sD,alpha,xval.sD);\n"
" Multiply(result.sE,alpha,xval.sE);\n"
" Multiply(result.sF,alpha,xval.sF);\n"
" #endif\n"
" #endif\n"
" // The final multiplication with alpha and the addition with beta*C\n"
" #ifdef HAVE_ALPHA_BETA\n"
" COMPUTE_FLOATN xval=c_value;\n"
" COMPUTE_FLOATN yval=CONVERT_COMPUTE_FLOATN(cgn[index]);\n"
" #if VWN == 1\n"
" AXPBY(result,alpha,xval,beta,yval);\n"
" #elif VWN == 2\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" #elif VWN == 4\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" AXPBY(result.z,alpha,xval.z,beta,yval.z);\n"
" AXPBY(result.w,alpha,xval.w,beta,yval.w);\n"
" #elif VWN == 8\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" #elif VWN == 16\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" AXPBY(result.s8,alpha,xval.s8,beta,yval.s8);\n"
" AXPBY(result.s9,alpha,xval.s9,beta,yval.s9);\n"
" AXPBY(result.sA,alpha,xval.sA,beta,yval.sA);\n"
" AXPBY(result.sB,alpha,xval.sB,beta,yval.sB);\n"
" AXPBY(result.sC,alpha,xval.sC,beta,yval.sC);\n"
" AXPBY(result.sD,alpha,xval.sD,beta,yval.sD);\n"
" AXPBY(result.sE,alpha,xval.sE,beta,yval.sE);\n"
" AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n"
" #endif\n"
" #endif\n"
" \n"
" \n"
// Bias and activation (compiled only when BIAS_TYPE > 0):
//   BIAS_TYPE == 1: bias = epm[_ni], converted to COMPUTE_FLOATN
//   BIAS_TYPE == 5: egm is reinterpreted as __global intN and read at
//                   idm * (dstride/VWN) + idn, kept as intN (not converted)
//   otherwise:      egm[idm * (dstride/VWN) + idn], converted to COMPUTE_FLOATN
// Each lane then goes through the DEAL_BIAS macro (defined elsewhere). After
// that, RELU clamps the whole vector at 0 and RELU6 clamps it to [0, 6].
// dstride is only read on the BIAS_TYPE > 1 paths.
// NOTE(review): RELU/RELU6 sit inside "#if BIAS_TYPE>0", so without a bias this
// function applies no activation. Confirm that is intended.
"#if BIAS_TYPE>0\n"
" #if BIAS_TYPE == 1\n"
" COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(epm[_ni]);\n"
" #elif BIAS_TYPE == 5\n"
" int index_bias=idm*(dstride/VWN)+idn;\n"
" intN eval=((__global intN*)egm)[index_bias];\n"
" #else\n"
" int index_bias=idm*(dstride/VWN)+idn;\n"
" COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(egm[index_bias]);\n"
" #endif\n"
" \n"
" #if VWN == 1\n"
" DEAL_BIAS(result,eval);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 2\n"
" DEAL_BIAS(result.x,eval.x);\n"
" DEAL_BIAS(result.y,eval.y);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 4\n"
" DEAL_BIAS(result.x,eval.x);\n"
" DEAL_BIAS(result.y,eval.y);\n"
" DEAL_BIAS(result.z,eval.z);\n"
" DEAL_BIAS(result.w,eval.w);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 8\n"
" DEAL_BIAS(result.s0,eval.s0);\n"
" DEAL_BIAS(result.s1,eval.s1);\n"
" DEAL_BIAS(result.s2,eval.s2);\n"
" DEAL_BIAS(result.s3,eval.s3);\n"
" DEAL_BIAS(result.s4,eval.s4);\n"
" DEAL_BIAS(result.s5,eval.s5);\n"
" DEAL_BIAS(result.s6,eval.s6);\n"
" DEAL_BIAS(result.s7,eval.s7);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 16\n"
" DEAL_BIAS(result.s0,eval.s0);\n"
" DEAL_BIAS(result.s1,eval.s1);\n"
" DEAL_BIAS(result.s2,eval.s2);\n"
" DEAL_BIAS(result.s3,eval.s3);\n"
" DEAL_BIAS(result.s4,eval.s4);\n"
" DEAL_BIAS(result.s5,eval.s5);\n"
" DEAL_BIAS(result.s6,eval.s6);\n"
" DEAL_BIAS(result.s7,eval.s7);\n"
" DEAL_BIAS(result.s8,eval.s8);\n"
" DEAL_BIAS(result.s9,eval.s9);\n"
" DEAL_BIAS(result.sA,eval.sA);\n"
" DEAL_BIAS(result.sB,eval.sB);\n"
" DEAL_BIAS(result.sC,eval.sC);\n"
" DEAL_BIAS(result.sD,eval.sD);\n"
" DEAL_BIAS(result.sE,eval.sE);\n"
" DEAL_BIAS(result.sF,eval.sF);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #endif\n"
"#endif\n"
" cgn[index]=CONVERT_FLOATN(result);\n"
"}\n"
"// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.\n"
"INLINE_FUNC void XgemmBody(const int kSizeM,const int kSizeN,const int kSizeK,const int4 stride,\n"
" const __global realM* restrict agm,const __global realN* restrict bgm,\n"
" #if BIAS_TYPE>0\n"
" __global realN* restrict egm,\n"
" #endif\n"
" __global realM* cgm,const real_arg alpha,const real_arg beta\n"
" #if SA == 1 && SB == 1\n"
" ,LOCAL_PTR realM* alm,LOCAL_PTR realN* blm\n"
" #elif SA == 1\n"
" ,LOCAL_PTR realM* alm\n"
" #elif SB == 1\n"
" ,LOCAL_PTR realN* blm\n"
" #endif\n"
" ) {\n"
" #ifdef OUTPUTMN\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATN cpn[MWI*(NWI/VWN)]; // MWI*NWI\n"
" #else\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATM cpm[NWI*(MWI/VWM)]; // NWI*MWI\n"
" #endif\n"
" // Combined thread identifier (volatile to disable caching)\n"
" #if SA == 1 || SB == 1\n"
" volatile int tid=get_local_id(0)+MDIMC*get_local_id(1);\n"
" #endif\n"
" // Initializes the accumulation registers\n"
" #ifdef OUTPUTMN\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI; _mi += 1) {\n"
" cpn[_mi*(NWI/VWN)+_ni]=InitAccRegistersN();\n"
" }\n"
" }\n"
" #else\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI; _ni += 1) {\n"
" cpm[_ni*(MWI/VWM)+_mi]=InitAccRegisters();\n"
" }\n"
" }\n"
" #endif\n"
" // Loops over all workgroup tiles\n"
" #if SA == 1 || SB == 1\n"
" // Allocates workitem-private memory (registers)\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n"
" \n"
" for (int kwg=0; kwg<kSizeK; kwg += KWG) {\n"
" // Loads data: off-chip --> local (matrix A)\n"
" #if SA == 1\n"
" GlobalToLocalA(agm,alm,kSizeM,tid,kwg);\n"
" #endif\n"
" // Loads data: off-chip --> local (matrix B)\n"
" #if SB == 1\n"
" GlobalToLocalB(bgm,blm,kSizeN,tid,kwg);\n"
" #endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" // Loops over all workitem tiles,unrolled by a factor KWI\n"
" for (int pwi=0; pwi<KWG; pwi += KWI) {\n"
" #pragma unroll\n"
" for (int _pit=0; _pit<KWI; _pit += 1) {\n"
" #if SA == 0 || SB == 0\n"
" int idk=kwg+pwi+_pit;\n"
" #endif\n"
" int kg=pwi+_pit;\n"
" // Loads matrix A (kernel 0) or matrix B (kernel 1)\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" // Loads data: local --> private (matrix A)\n"
" #if SA == 1\n"
" apm[_mi]=CONVERT_COMPUTE_FLOATM(LocalToPrivateA(alm,_mi,kg));\n"
" // Loads data: off-chip --> private (matrix A)\n"
" #elif SA == 0\n"
" apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateA(agm,_mi,kSizeM,idk));\n"
" #endif\n"
" }\n"
" // Loads matrix B (kernel 0) or matrix A (kernel 1)\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" // Loads data: local --> private (matrix B)\n"
" #if SB == 1\n"
" bpm[_ni]=CONVERT_COMPUTE_FLOATN(LocalToPrivateB(blm,_ni,kg));\n"
" // Loads data: off-chip --> private (matrix B)\n"
" #else\n"
" bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateB(bgm,_ni,kSizeN,idk));\n"
" #endif\n"
" }\n"
" // Performs the accumulation (Cpm += Apm*Bpm)\n"
" #ifdef OUTPUTMN\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" const COMPUTE_FLOATM aval=apm[_mi];\n"
" #if VWM == 1\n"
" // [MWI/VWM,VWM,NWI/VWN,VWN]\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval,bpm[_ni]);\n"
" #elif VWM == 2\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" #elif VWM == 4\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.z,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.w,bpm[_ni]);\n"
" #elif VWM == 8\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4)*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5)*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6)*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7)*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" #elif VWM == 16\n"
" cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni],aval.s8,bpm[_ni]);\n"
" cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni],aval.s9,bpm[_ni]);\n"
" cpn[(_mi*VWM+10)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+10)*(NWI/VWN)+_ni],aval.sA,bpm[_ni]);\n"
" cpn[(_mi*VWM+11)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+11)*(NWI/VWN)+_ni],aval.sB,bpm[_ni]);\n"
" cpn[(_mi*VWM+12)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+12)*(NWI/VWN)+_ni],aval.sC,bpm[_ni]);\n"
" cpn[(_mi*VWM+13)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+13)*(NWI/VWN)+_ni],aval.sD,bpm[_ni]);\n"
" cpn[(_mi*VWM+14)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+14)*(NWI/VWN)+_ni],aval.sE,bpm[_ni]);\n"
" cpn[(_mi*VWM+15)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+15)*(NWI/VWN)+_ni],aval.sF,bpm[_ni]);\n"
" #endif\n"
" }\n"
" }\n"
" #else\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" const COMPUTE_FLOATM aval=apm[_mi];\n"
" #if VWN == 1\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni]);\n"
" #elif VWN == 2\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].y);\n"
" #elif VWN == 4\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].y);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bpm[_ni].z);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bpm[_ni].w);\n"
" #elif VWN == 8\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].s0);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].s1);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bpm[_ni].s2);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bpm[_ni].s3);\n"
" cpm[(_ni*VWN+4)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4)*(MWI/VWM)+_mi],aval,bpm[_ni].s4);\n"
" cpm[(_ni*VWN+5)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5)*(MWI/VWM)+_mi],aval,bpm[_ni].s5);\n"
" cpm[(_ni*VWN+6)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6)*(MWI/VWM)+_mi],aval,bpm[_ni].s6);\n"
" cpm[(_ni*VWN+7)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7)*(MWI/VWM)+_mi],aval,bpm[_ni].s7);\n"
" #elif VWN == 16\n"
" cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi],aval,bpm[_ni].s0);\n"
" cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi],aval,bpm[_ni].s1);\n"
" cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi],aval,bpm[_ni].s2);\n"
" cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi],aval,bpm[_ni].s3);\n"
" cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi],aval,bpm[_ni].s4);\n"
" cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi],aval,bpm[_ni].s5);\n"
" cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi],aval,bpm[_ni].s6);\n"
" cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi],aval,bpm[_ni].s7);\n"
" cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi],aval,bpm[_ni].s8);\n"
" cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi],aval,bpm[_ni].s9);\n"
" cpm[(_ni*VWN+10)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+10)*(MWI/VWM)+_mi],aval,bpm[_ni].sA);\n"
" cpm[(_ni*VWN+11)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+11)*(MWI/VWM)+_mi],aval,bpm[_ni].sB);\n"
" cpm[(_ni*VWN+12)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+12)*(MWI/VWM)+_mi],aval,bpm[_ni].sC);\n"
" cpm[(_ni*VWN+13)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+13)*(MWI/VWM)+_mi],aval,bpm[_ni].sD);\n"
" cpm[(_ni*VWN+14)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+14)*(MWI/VWM)+_mi],aval,bpm[_ni].sE);\n"
" cpm[(_ni*VWN+15)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+15)*(MWI/VWM)+_mi],aval,bpm[_ni].sF);\n"
" #endif\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" #else\n"
" // Allocates workitem-private memory (registers)\n"
" int baseIndexA=GlobalIndexA();\n"
" int baseIndexB=GlobalIndexB();\n"
" #pragma unroll\n"
" for (int _kj=0; _kj<kSizeK; _kj += 4) {\n"
" #ifdef OUTPUTMN\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n"
" \n"
" #pragma unroll\n"
" for(int _ki=0; _ki<4; _ki += 1) {\n"
" int idk=_kj+_ki;\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" // Loads data: off-chip --> private (matrix B)\n"
" bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n"
" }\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" const COMPUTE_FLOATM aval=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #if VWM == 1\n"
" // [MWI/VWM,VWM,NWI/VWN,VWN]\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval,bpm[_ni]);\n"
" #elif VWM == 2\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" #elif VWM == 4\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.z,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.w,bpm[_ni]);\n"
" #elif VWM == 8\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4)*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5)*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6)*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7)*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" #elif VWM == 16\n"
" cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni],aval.s8,bpm[_ni]);\n"
" cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni],aval.s9,bpm[_ni]);\n"
" cpn[(_mi*VWM+10)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+10)*(NWI/VWN)+_ni],aval.sA,bpm[_ni]);\n"
" cpn[(_mi*VWM+11)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+11)*(NWI/VWN)+_ni],aval.sB,bpm[_ni]);\n"
" cpn[(_mi*VWM+12)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+12)*(NWI/VWN)+_ni],aval.sC,bpm[_ni]);\n"
" cpn[(_mi*VWM+13)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+13)*(NWI/VWN)+_ni],aval.sD,bpm[_ni]);\n"
" cpn[(_mi*VWM+14)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+14)*(NWI/VWN)+_ni],aval.sE,bpm[_ni]);\n"
" cpn[(_mi*VWM+15)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+15)*(NWI/VWN)+_ni],aval.sF,bpm[_ni]);\n"
" #endif\n"
" }\n"
" }\n"
" }\n"
" #else\n"
" \n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n"
" #pragma unroll\n"
" for(int _ki=0; _ki<4; _ki += 1) {\n"
" int idk=_kj+_ki;\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" // Loads data: off-chip --> private (matrix B)\n"
" apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n"
" }\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" const COMPUTE_FLOATN bval=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" const COMPUTE_FLOATM aval=apm[_mi];\n"
" #if VWN == 1\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval);\n"
" #elif VWN == 2\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.y);\n"
" #elif VWN == 4\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.y);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bval.z);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bval.w);\n"
" #elif VWN == 8\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.s0);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.s1);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bval.s2);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bval.s3);\n"
" cpm[(_ni*VWN+4)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4)*(MWI/VWM)+_mi],aval,bval.s4);\n"
" cpm[(_ni*VWN+5)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5)*(MWI/VWM)+_mi],aval,bval.s5);\n"
" cpm[(_ni*VWN+6)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6)*(MWI/VWM)+_mi],aval,bval.s6);\n"
" cpm[(_ni*VWN+7)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7)*(MWI/VWM)+_mi],aval,bval.s7);\n"
" #elif VWN == 16\n"
" cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi],aval,bval.s0);\n"
" cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi],aval,bval.s1);\n"
" cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi],aval,bval.s2);\n"
" cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi],aval,bval.s3);\n"
" cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi],aval,bval.s4);\n"
" cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi],aval,bval.s5);\n"
" cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi],aval,bval.s6);\n"
" cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi],aval,bval.s7);\n"
" cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi],aval,bval.s8);\n"
" cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi],aval,bval.s9);\n"
" cpm[(_ni*VWN+10)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+10)*(MWI/VWM)+_mi],aval,bval.sA);\n"
" cpm[(_ni*VWN+11)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+11)*(MWI/VWM)+_mi],aval,bval.sB);\n"
" cpm[(_ni*VWN+12)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+12)*(MWI/VWM)+_mi],aval,bval.sC);\n"
" cpm[(_ni*VWN+13)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+13)*(MWI/VWM)+_mi],aval,bval.sD);\n"
" cpm[(_ni*VWN+14)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+14)*(MWI/VWM)+_mi],aval,bval.sE);\n"
" cpm[(_ni*VWN+15)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+15)*(MWI/VWM)+_mi],aval,bval.sF);\n"
" #endif\n"
" }\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
" #endif\n"
" \n"
" #if GLOBAL_MEM_FENCE == 1\n"
" barrier(CLK_GLOBAL_MEM_FENCE);\n"
" #endif\n"
" #ifdef OUTPUTMN\n"
" INT2 baseOffset=StoreIndexN();\n"
" #if BIAS_TYPE == 1\n"
" #pragma promote_to_registers\n"
" realN epm[NWI/VWN]; // MWI*1\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #if STRN == 0\n"
" int idn=_ni+baseOffset.index[1];\n"
" #elif STRN == 1\n"
" int idn=baseOffset.index[1]+_ni*NDIMC;\n"
" #endif\n"
" epm[_ni]=egm[idn];\n"
" }\n"
" #endif\n"
" \n"
" \n"
" \n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI; _mi += 1) {\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" StoreResultsN((__global realN* )cgm,cpn[_mi*(NWI/VWN)+_ni],\n"
" baseOffset,\n"
" #if BIAS_TYPE>1\n"
" (__global realN*)egm,\n"
" #elif BIAS_TYPE == 1\n"
" (realN*)epm,\n"
" #endif\n"
" _mi,_ni,stride.s2,stride.s3,alpha,beta);\n"
" }\n"
" }\n"
" \n"
" #else\n"
" INT2 baseOffset=StoreIndexM();\n"
" // Stores an MWG*NWG tile of results and performs the multiplication with alpha and beta\n"
" const int cld=kSizeM;\n"
" \n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI; _ni += 1) {\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" StoreResultsM(cgm,cpm[_ni*(MWI/VWM)+_mi],baseOffset,_mi,_ni,cld,alpha,beta);\n"
" }\n"
" }\n"
" #endif\n"
"}\n"
"// Main entry point of the kernel. This is the regular full version.\n"
"#if RELAX_WORKGROUP_SIZE == 1\n"
" __kernel\n"
"#else\n"
" __kernel __attribute__((reqd_work_group_size(MDIMC,NDIMC,1)))\n"
"#endif\n"
"void Xgemm(const int kSizeM,const int kSizeN,const int kSizeK,\n"
" const real_arg arg_alpha,\n"
" const real_arg arg_beta,\n"
" const __global realM* restrict agm,// [K,M]\n"
" const __global realN* restrict bgm,// [K,N]\n"
" #if BIAS_TYPE>0\n"
" __global realN* restrict egm,// [N]\n"
" #endif\n"
" __global realM* cgm,\n"
" __private const int4 offset,\n"
" __private const int4 stride\n"
") {\n"
" \n"
" // Adds the offsets (in case of use of a single temporary buffer for A,B,and C)\n"
" agm=(const __global realM*)((const __global real*)agm+offset.s0);\n"
" bgm=(const __global realN*)((const __global real*)bgm+offset.s1);\n"
" cgm=(__global realM*)((__global real*)cgm+offset.s2);\n"
" \n"
" #if BIAS_TYPE>0\n"
" egm=(__global realN*)((__global real*)egm+offset.s3);\n"
" #endif\n"
" // Allocates workgroup-private memory (local memory)\n"
" #if SA == 1\n"
" __local realM alm[KWG*MWG/VWM];\n"
" #endif\n"
" #if SB == 1\n"
" __local realN blm[KWG*NWG/VWN];\n"
" #endif\n"
" \n"
" // Computes the matrix-multiplication and stores the result in global memory\n"
" #if SA == 1 && SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta,alm,blm);\n"
" #elif SA == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta,alm);\n"
" #elif SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta,blm);\n"
" #else\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta);\n"
" #endif\n"
"}\n"
"#if RELAX_WORKGROUP_SIZE == 1\n"
" __kernel\n"
"#else\n"
" __kernel __attribute__((reqd_work_group_size(MDIMC,NDIMC,1)))\n"
"#endif\n"
"void XgemmBatched(const int kSizeM,\n"
" const int kSizeN,\n"
" const int kSizeK,\n"
" const real_arg arg_alpha,\n"
" const real_arg arg_beta,\n"
" const __global realM* restrict agm,\n"
" const __global realN* restrict bgm,\n"
" #if BIAS_TYPE>0\n"
" __global realN* restrict egm,\n"
" #endif\n"
" __global realM* cgm,\n"
" const int4 batch_offset,// [batch_offset_a,batch_offset_b,batch_offset_c,batch_offset_e]\n"
" const int4 stride,// [stride_a,stride_b,stride_c,stride_e]\n"
" /*\n"
" total_batch -> [loop_y,loop_x]\n"
" with group batch -> [loop_y,loop_x/group_num]\n"
" group_size == loop_x/group_num\n"
" */\n"
" const int4 group // [group_num_a,group_num_b,group_num_e,loop_x]\n"
") {\n"
" const int batch=get_group_id(2);\n"
" \n"
" // Sets the offsets\n"
" const int a_offset=((batch/group.w)*group.x+(batch % group.w)/group.x)*batch_offset.x;\n"
" const int b_offset=((batch/group.w)*group.y+(batch % group.w)/group.y)*batch_offset.y;\n"
" const int c_offset=batch*batch_offset.z;\n"
" const __global realM* restrict agm_=&agm[a_offset/VWM];\n"
" const __global realN* restrict bgm_=&bgm[b_offset/VWN];\n"
" __global realM* restrict cgm_=&cgm[c_offset/VWM];\n"
" \n"
" #if BIAS_TYPE>0\n"
" const int e_offset=((batch/group.w)*group.z+(batch % group.w)/group.z)*batch_offset.w;\n"
" __global realN* restrict egm_=&egm[e_offset/VWN];\n"
" #endif\n"
" \n"
" // Allocates workgroup-private memory (local memory)\n"
" #if SA == 1\n"
" __local realM alm[KWG*MWG/VWM];\n"
" #endif\n"
" #if SB == 1\n"
" __local realN blm[KWG*NWG/VWN];\n"
" #endif\n"
" // Computes the matrix-multiplication and stores the result in global memory\n"
" #if SA == 1 && SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta,alm,blm);\n"
" #elif SA == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta,alm);\n"
" #elif SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta,blm);\n"
" #else\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta);\n"
" #endif\n"
"}\n"
;
#endif
// OpenCL C source text for the image-based `cast` kernel, stored as a string
// (presumably handed to the OpenCL runtime to build at load time -- verify
// against the code that consumes this table).
//
// Macros used but NOT defined in this text; they must be supplied by the host
// (presumably as -D build options -- confirm against the caller):
//   RI_DATA / WI_DATA   read/write helpers applied to the image2d_t arguments
//   INPUT_TYPE_I4       4-wide type of a texel read from `input`
//   CONVERT_OUTPUT_I4   conversion applied to the value before it is written
// Optional switches: MNN_SUPPORT_FP16 (enables cl_khr_fp16), TO_BOOL.
//
// Kernel `cast` (3-D NDRange):
//   gid0 = x inside `width`, gid1 = y inside `height`,
//   gid2 = batch_idx * channelBlock + channel_idx.
//   Work items beyond global_size_dim0/1/2 return early (DEAL_NON_UNIFORM_DIM3).
//   Each work item reads and writes the texel at
//   (channel_idx * width + x, batch_idx * height + y), i.e. channel blocks are
//   laid out side by side horizontally and batches stacked vertically.
const char* cast = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// NDRange bounds-check helpers (the dims are passed as leading kernel args).
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void cast(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channelBlock\n"
" ) {\n"
" const int width_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n"
" \n"
" const int batch_idx=batch_channel_idx/channelBlock;\n"
" const int channel_idx=batch_channel_idx % channelBlock;\n"
" \n"
// TO_BOOL: per component, zero stays 0 and any nonzero value becomes 1.
// (`value == (int4)0` yields -1/0 per lane; the vector ?: selects per lane.)
"#ifdef TO_BOOL\n"
" int4 value=convert_int4(RI_DATA(input,SAMPLER,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx)));\n"
" value=value == (int4)0 ? (int4)0 : (int4)1;\n"
" WI_DATA(output,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx),CONVERT_OUTPUT_I4(value));\n"
// Otherwise: the texel is only passed through CONVERT_OUTPUT_I4.
"#else\n"
" INPUT_TYPE_I4 value=RI_DATA(input,SAMPLER,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx));\n"
" WI_DATA(output,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx),CONVERT_OUTPUT_I4(value));\n"
"#endif\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source text for buffer-to-buffer layout conversion kernels, stored
// as a string (presumably built by the OpenCL runtime at load time -- verify
// against the consumer of this table).
//
// Kernels in this string:
//   buffer_convert_to_buffer      element-wise reorder between NCHW / NHWC /
//                                 NC4HW4 as selected by INPUT_FORMAT/OUTPUT_FORMAT
//   buffer_copy_to_buffer         flat copy, 4 consecutive elements per work item
//   conv2d_filter_buffer_to_nc4hw4_buffer[_floatin]
//                                 conv weights [oc][ic][kh][kw] ->
//                                 [ic][ceil(oc/4)][kh*kw][4], zero-padded oc
//   dw_filter_buffer_to_nc4hw4_buffer[_floatin]
//                                 depthwise weights [1][C][fh][fw] ->
//                                 [fh*fw][ceil(C/4)][4], zero-padded C
//   The *_floatin variants are identical except that `input_ptr` is float.
//
// Macros used but NOT defined in this text (presumably host -D build options;
// confirm against the caller): INPUT_TYPE, OUTPUT_TYPE, CONVERT_OUTPUT4,
// FLOAT, FLOAT4, INPUT_FORMAT, OUTPUT_FORMAT; optional PACK_LEAVE and
// MNN_SUPPORT_FP16.
// NOTE(review): the in-source comments say "to image", but every kernel here
// writes a __global buffer (plain stores or vstore4).
// NOTE(review): this file looks machine-generated from .cl sources; if so,
// kernel changes belong in the .cl originals -- TODO confirm.
const char* buffer_convert_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// NDRange bounds-check helpers (the dims are passed as leading kernel args).
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// Layout ids compared against the host-supplied INPUT_FORMAT / OUTPUT_FORMAT.
"#define MNN_DATA_FORMAT_NCHW 0\n"
"#define MNN_DATA_FORMAT_NHWC 1\n"
"#define MNN_DATA_FORMAT_NC4HW4 2\n"
"#define MNN_DATA_FORMAT_C4NHW4 3\n"
// buffer_convert_to_buffer (3-D NDRange): gid0 = h*W + w, gid1 = c, gid2 = n,
// with shape = (N, C, H, W). Copies one element from input layout to output
// layout. The NC4HW4 branch addresses ((((c/4)*N + n)*H + h)*W + w)*4 + c%4,
// i.e. the channel block is the outermost dimension (ahead of batch).
// NOTE(review): MNN_DATA_FORMAT_C4NHW4 (3) has no branch below; selecting it
// would leave input_offset/output_offset undeclared and fail the build.
"__kernel void buffer_convert_to_buffer(GLOBAL_SIZE_3_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __private const int4 shape,// N C H W\n"
" __global OUTPUT_TYPE *output_ptr\n"
") {\n"
" int wh=get_global_id(0);\n"
" int c=get_global_id(1);\n"
" int n=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(wh,c,n);\n"
" int w=wh % shape.w;\n"
" int h=wh/shape.w;\n"
" \n"
"#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int input_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n"
"#endif\n"
"#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int output_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int output_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int output_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n"
"#endif\n"
" output_ptr[output_offset]=input_ptr[input_offset];\n"
"}\n"
// buffer_copy_to_buffer (2-D NDRange): work item x copies elements
// [4x, 4x+3]; gid1 is only bounds-checked, never used for addressing.
// `size` is an element count (the in-source "N C H W" note appears stale).
// With PACK_LEAVE the final partial group (4x+3 >= size) is copied one element
// at a time; without it vload4/vstore4 are always used, so the caller
// presumably guarantees size is a multiple of 4 -- verify.
"__kernel void buffer_copy_to_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const INPUT_TYPE *input_ptr,\n"
" __global OUTPUT_TYPE *output_ptr,\n"
" __private const int size // N C H W\n"
") {\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int offset=x << 2;\n"
"#ifdef PACK_LEAVE\n"
" if(offset+3 >= size){\n"
" for(int i=0; i<size-offset; ++i){\n"
" output_ptr[offset+i]=(OUTPUT_TYPE)input_ptr[offset+i];\n"
" }\n"
" } else {\n"
"#endif\n"
" vstore4(CONVERT_OUTPUT4(vload4(0,input_ptr+offset)),0,output_ptr+offset);\n"
"#ifdef PACK_LEAVE\n"
" }\n"
"#endif\n"
"}\n"
// conv2d_filter_buffer_to_nc4hw4_buffer (2-D NDRange): gid0 = ic,
// gid1 = (oc/4) * height_width_size + (kh_idx * kernel_shape.y + kw_idx).
// Reads 4 consecutive output channels (stride ic_h_w_size), zero-filling
// channels >= output_channel, and stores them with one vstore4 at
// (ic * height_width_size * ceil(output_channel/4) + gid1) * 4.
// kernel_shape.y acts as the kernel width; height_width_size is presumably
// kh*kw and ic_h_w_size presumably ic*kh*kw -- confirm on the host side.
"// convert kernel : from buffer(oihw) to image(oc/4 h w ,ic oc4)\n"
"__kernel void conv2d_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input_ptr,\n"
" __private const int output_channel,\n"
" __private const int2 kernel_shape,\n"
" __private const int ic_h_w_size,\n"
" __private const int height_width_size,\n"
" __global FLOAT *output) {\n"
" int image_width_idx=get_global_id(0); // ic\n"
" int image_height_idx=get_global_id(1); // oc/4 h w\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int input_channel_4_idx=image_width_idx;\n"
" const int output_channel_4_idx=(image_height_idx/height_width_size)*4;\n"
" const int height_width_idx=image_height_idx % height_width_size;\n"
" const int buffer_height_idx=height_width_idx/kernel_shape.y;\n"
" const int buffer_width_idx=height_width_idx % kernel_shape.y;\n"
" const int buffer_offset=output_channel_4_idx*ic_h_w_size+input_channel_4_idx*height_width_size +\n"
" buffer_height_idx*kernel_shape.y+buffer_width_idx;\n"
" FLOAT4 output_values=0;\n"
" if (output_channel_4_idx<output_channel) {\n"
" const int remain_channel=output_channel-output_channel_4_idx;\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.w=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" const int out_offset=(image_width_idx*height_width_size*((output_channel+3)/4)+image_height_idx)*4;\n"
" vstore4(output_values,0,output+out_offset);\n"
"}\n"
// Same as conv2d_filter_buffer_to_nc4hw4_buffer, but the source weights are
// float and are converted to FLOAT on load.
"// convert kernel : from buffer(oihw) to image(oc/4 h w ,ic oc4)\n"
"__kernel void conv2d_filter_buffer_to_nc4hw4_buffer_floatin(GLOBAL_SIZE_2_DIMS\n"
" __global const float *input_ptr,\n"
" __private const int output_channel,\n"
" __private const int2 kernel_shape,\n"
" __private const int ic_h_w_size,\n"
" __private const int height_width_size,\n"
" __global FLOAT *output) {\n"
" int image_width_idx=get_global_id(0); // ic\n"
" int image_height_idx=get_global_id(1); // oc/4 h w\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" const int input_channel_4_idx=image_width_idx;\n"
" const int output_channel_4_idx=(image_height_idx/height_width_size)*4;\n"
" const int height_width_idx=image_height_idx % height_width_size;\n"
" const int buffer_height_idx=height_width_idx/kernel_shape.y;\n"
" const int buffer_width_idx=height_width_idx % kernel_shape.y;\n"
" const int buffer_offset=output_channel_4_idx*ic_h_w_size+input_channel_4_idx*height_width_size +\n"
" buffer_height_idx*kernel_shape.y+buffer_width_idx;\n"
" FLOAT4 output_values=0;\n"
" if (output_channel_4_idx<output_channel) {\n"
" const int remain_channel=output_channel-output_channel_4_idx;\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.w=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += ic_h_w_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset=mad24(1,ic_h_w_size,offset);\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" const int out_offset=(image_width_idx*height_width_size*((output_channel+3)/4)+image_height_idx)*4;\n"
" vstore4(output_values,0,output+out_offset);\n"
"}\n"
// dw_filter_buffer_to_nc4hw4_buffer (2-D NDRange): gid0 = fh*fw position,
// gid1 = channel block (ceil(Cout/4)). Only kernel_shape.x == 1 (multiplier 1)
// is handled; otherwise zeros are written. Reads 4 consecutive channels
// (stride height_width_size, presumably fh*fw), zero-fills channels >= Cout,
// and stores them at (gid0 * ceil(Cout/4) + gid1) * 4.
"// convert kernel from buffer(mihw) to image(ic/4,ic4 h w m)\n"
"// but now dw only support m == 1\n"
"__kernel void dw_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input_ptr,\n"
" __private const int4 kernel_shape,//[1,Cout,fh,fw]\n"
" __private const int height_width_size,\n"
" __global FLOAT *output) {\n"
" const int image_width_idx=get_global_id(0);//fh*fw\n"
" const int image_height_idx=get_global_id(1);//UP_DIV(Cout,4)\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" FLOAT4 output_values=0;\n"
" if (kernel_shape.x == 1) {\n"
" const int input_channel_4_idx=image_height_idx*4;\n"
" const int buffer_height_idx=image_width_idx/kernel_shape.w;\n"
" const int buffer_width_idx=image_width_idx % kernel_shape.w;\n"
" const int buffer_offset =\n"
" mad24(mad24(input_channel_4_idx,kernel_shape.z,buffer_height_idx),kernel_shape.w,buffer_width_idx);\n"
" //input [1,Cout,fh,fw]\n"
" //index:[0,input_channel_4_idx,buffer_height_idx,buffer_width_idx]\n"
" const int remain_channel=kernel_shape.y-input_channel_4_idx;\n"
" if (input_channel_4_idx<kernel_shape.y) {\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.w=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" }\n"
" //output NC4HW4 [1,fw*fh,1,Cout/4]x oc4\n"
" //index: [0,image_width_idx,0,image_height_idx]\n"
" const int out_offset=(image_width_idx*((kernel_shape.y+3)/4)+image_height_idx)*4;\n"
" vstore4(output_values,0,output+out_offset);\n"
"}\n"
// Same as dw_filter_buffer_to_nc4hw4_buffer, but the source weights are float
// and are converted to FLOAT on load.
"__kernel void dw_filter_buffer_to_nc4hw4_buffer_floatin(GLOBAL_SIZE_2_DIMS\n"
" __global const float *input_ptr,\n"
" __private const int4 kernel_shape,//[1,Cout,fh,fw]\n"
" __private const int height_width_size,\n"
" __global FLOAT *output) {\n"
" const int image_width_idx=get_global_id(0);//fh*fw\n"
" const int image_height_idx=get_global_id(1);//UP_DIV(Cout,4)\n"
" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n"
" FLOAT4 output_values=0;\n"
" if (kernel_shape.x == 1) {\n"
" const int input_channel_4_idx=image_height_idx*4;\n"
" const int buffer_height_idx=image_width_idx/kernel_shape.w;\n"
" const int buffer_width_idx=image_width_idx % kernel_shape.w;\n"
" const int buffer_offset =\n"
" mad24(mad24(input_channel_4_idx,kernel_shape.z,buffer_height_idx),kernel_shape.w,buffer_width_idx);\n"
" //input [1,Cout,fh,fw]\n"
" //index:[0,input_channel_4_idx,buffer_height_idx,buffer_width_idx]\n"
" const int remain_channel=kernel_shape.y-input_channel_4_idx;\n"
" if (input_channel_4_idx<kernel_shape.y) {\n"
" if (remain_channel >= 4) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.w=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 3) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.z=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 2) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" offset += height_width_size;\n"
" output_values.y=(FLOAT)(*(input_ptr+offset));\n"
" } else if (remain_channel == 1) {\n"
" int offset=buffer_offset;\n"
" output_values.x=(FLOAT)(*(input_ptr+offset));\n"
" }\n"
" }\n"
" }\n"
" //output NC4HW4 [1,fw*fh,1,Cout/4]x oc4\n"
" //index: [0,image_width_idx,0,image_height_idx]\n"
" const int out_offset=(image_width_idx*((kernel_shape.y+3)/4)+image_height_idx)*4;\n"
" vstore4(output_values,0,output+out_offset);\n"
"}\n"
;
#endif
const char* matmul = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) ""if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { ""return; ""}\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void matmul(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,__private const int channels,\n"
" __private const int channel_blocks) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_idx);\n"
" FLOAT4 a;\n"
" FLOAT4 b0=0,b1=0,b2=0,b3=0;\n"
" #ifdef BIAS\n"
" FLOAT4 temp=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT result0=temp.x;\n"
" FLOAT result1=temp.y;\n"
" FLOAT result2=temp.z;\n"
" FLOAT result3=temp.w;\n"
" #else\n"
" FLOAT result0=0;\n"
" FLOAT result1=0;\n"
" FLOAT result2=0;\n"
" FLOAT result3=0;\n"
" #endif\n"
" for (short pos=0; pos<channel_blocks; pos += 1) {\n"
" a=RI_F(input_a,SAMPLER,(int2)(pos,height_idx));\n"
" short remain=(pos+1)*4-channels;\n"
" b0=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4));\n"
" b1=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4+1));\n"
" b2=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4+2));\n"
" b3=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4+3));\n"
" if (remain == 3) {\n"
" b1=0;\n"
" b2=0;\n"
" b3=0;\n"
" } else if (remain == 2) {\n"
" b2=0;\n"
" b3=0;\n"
" } else if (remain == 1) {\n"
" b3=0;\n"
" }\n"
" FLOAT4 btmp0=(FLOAT4)(b0.s0,b1.s0,b2.s0,b3.s0);\n"
" FLOAT4 btmp1=(FLOAT4)(b0.s1,b1.s1,b2.s1,b3.s1);\n"
" FLOAT4 btmp2=(FLOAT4)(b0.s2,b1.s2,b2.s2,b3.s2);\n"
" FLOAT4 btmp3=(FLOAT4)(b0.s3,b1.s3,b2.s3,b3.s3);\n"
" result0 += dot(a,btmp0);\n"
" result1 += dot(a,btmp1);\n"
" result2 += dot(a,btmp2);\n"
" result3 += dot(a,btmp3);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,height_idx),(FLOAT4)(result0,result1,result2,result3));\n"
"}\n"
"__kernel void matmul_transB(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,__private const int channels,\n"
" __private const int channel_blocks) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_idx);\n"
" FLOAT4 a;\n"
" FLOAT4 b0=0,b1=0,b2=0,b3=0;\n"
" #ifdef BIAS\n"
" FLOAT4 temp=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT result0=temp.x;\n"
" FLOAT result1=temp.y;\n"
" FLOAT result2=temp.z;\n"
" FLOAT result3=temp.w;\n"
" #else\n"
" FLOAT result0=0;\n"
" FLOAT result1=0;\n"
" FLOAT result2=0;\n"
" FLOAT result3=0;\n"
" #endif\n"
" for (short pos=0; pos<channel_blocks; pos += 1) {\n"
" a=RI_F(input_a,SAMPLER,(int2)(pos,height_idx));\n"
" short remain=(pos+1)*4-channels;\n"
" b0=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4));\n"
" b1=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4+1));\n"
" b2=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4+2));\n"
" b3=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4+3));\n"
" if (remain == 3) {\n"
" a.y=0;\n"
" a.z=0;\n"
" a.w=0;\n"
" } else if (remain == 2) {\n"
" a.z=0;\n"
" a.w=0;\n"
" } else if (remain == 1) {\n"
" a.w=0;\n"
" }\n"
" result0 += dot(a,b0);\n"
" result1 += dot(a,b1);\n"
" result2 += dot(a,b2);\n"
" result3 += dot(a,b3);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,height_idx),(FLOAT4)(result0,result1,result2,result3));\n"
"}\n"
" __kernel void matmul_transA(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,\n"
" __private const int channels,\n"
" __private const int channel_blocks,\n"
" __private const int height) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_blocks_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_blocks_idx);\n"
" FLOAT4 v_zero=(FLOAT4)((FLOAT)0.0);\n"
" #ifdef BIAS\n"
" FLOAT4 result0=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
" FLOAT4 result3=result0;\n"
" #else\n"
" FLOAT4 result0=0;\n"
" FLOAT4 result1=0;\n"
" FLOAT4 result2=0;\n"
" FLOAT4 result3=0;\n"
" #endif\n"
" \n"
" for (short pos=0; pos<channel_blocks; pos += 1) {\n"
" FLOAT4 a0=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos));\n"
" FLOAT4 a1=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+1));\n"
" FLOAT4 a2=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+2));\n"
" FLOAT4 a3=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+3));\n"
" FLOAT4 b0=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos));\n"
" FLOAT4 b1=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos+1));\n"
" FLOAT4 b2=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos+2));\n"
" FLOAT4 b3=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos+3));\n"
" \n"
" short remain=(pos+1)*4-channels;\n"
" a3=((remain >= 1) ? v_zero : a3);\n"
" a2=((remain >= 2) ? v_zero : a2);\n"
" a1=((remain >= 3) ? v_zero : a1);\n"
" FLOAT4 a0_trans=(FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n"
" FLOAT4 a1_trans=(FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n"
" FLOAT4 a2_trans=(FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n"
" FLOAT4 a3_trans=(FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n"
" \n"
" FLOAT4 b0_trans=(FLOAT4)(b0.x,b1.x,b2.x,b3.x);\n"
" FLOAT4 b1_trans=(FLOAT4)(b0.y,b1.y,b2.y,b3.y);\n"
" FLOAT4 b2_trans=(FLOAT4)(b0.z,b1.z,b2.z,b3.z);\n"
" FLOAT4 b3_trans=(FLOAT4)(b0.w,b1.w,b2.w,b3.w);\n"
" //matmul\n"
" result0.x += dot(a0_trans,b0_trans);\n"
" result0.y += dot(a0_trans,b1_trans);\n"
" result0.z += dot(a0_trans,b2_trans);\n"
" result0.w += dot(a0_trans,b3_trans);\n"
" \n"
" result1.x += dot(a1_trans,b0_trans);\n"
" result1.y += dot(a1_trans,b1_trans);\n"
" result1.z += dot(a1_trans,b2_trans);\n"
" result1.w += dot(a1_trans,b3_trans);\n"
" \n"
" result2.x += dot(a2_trans,b0_trans);\n"
" result2.y += dot(a2_trans,b1_trans);\n"
" result2.z += dot(a2_trans,b2_trans);\n"
" result2.w += dot(a2_trans,b3_trans);\n"
" \n"
" result3.x += dot(a3_trans,b0_trans);\n"
" result3.y += dot(a3_trans,b1_trans);\n"
" result3.z += dot(a3_trans,b2_trans);\n"
" result3.w += dot(a3_trans,b3_trans);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx),result0);\n"
" if(4*height_blocks_idx+1 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+1),result1);\n"
" if(4*height_blocks_idx+2 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+2),result2);\n"
" if(4*height_blocks_idx+3 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+3),result3);\n"
"}\n"
"__kernel void matmul_transA_transB(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,\n"
" __private const int channels,\n"
" __private const int channel_blocks,\n"
" __private const int height) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_blocks_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_blocks_idx);\n"
" FLOAT4 v_zero=(FLOAT4)((FLOAT)0.0);\n"
" #ifdef BIAS\n"
" FLOAT4 result0=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
" FLOAT4 result3=result0;\n"
" #else\n"
" FLOAT4 result0=0;\n"
" FLOAT4 result1=0;\n"
" FLOAT4 result2=0;\n"
" FLOAT4 result3=0;\n"
" #endif\n"
" for (short pos=0; pos<channel_blocks; pos += 1) {\n"
" FLOAT4 a0=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos));\n"
" FLOAT4 a1=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+1));\n"
" FLOAT4 a2=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+2));\n"
" FLOAT4 a3=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+3));\n"
" FLOAT4 b0=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx));\n"
" FLOAT4 b1=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx+1));\n"
" FLOAT4 b2=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx+2));\n"
" FLOAT4 b3=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx+3));\n"
" \n"
" short remain=(pos+1)*4-channels;\n"
" a3=((remain >= 1) ? v_zero : a3);\n"
" a2=((remain >= 2) ? v_zero : a2);\n"
" a1=((remain >= 3) ? v_zero : a1);\n"
" FLOAT4 a0_trans=(FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n"
" FLOAT4 a1_trans=(FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n"
" FLOAT4 a2_trans=(FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n"
" FLOAT4 a3_trans=(FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n"
" //matmul\n"
" result0.x += dot(a0_trans,b0);\n"
" result0.y += dot(a0_trans,b1);\n"
" result0.z += dot(a0_trans,b2);\n"
" result0.w += dot(a0_trans,b3);\n"
" \n"
" result1.x += dot(a1_trans,b0);\n"
" result1.y += dot(a1_trans,b1);\n"
" result1.z += dot(a1_trans,b2);\n"
" result1.w += dot(a1_trans,b3);\n"
" \n"
" result2.x += dot(a2_trans,b0);\n"
" result2.y += dot(a2_trans,b1);\n"
" result2.z += dot(a2_trans,b2);\n"
" result2.w += dot(a2_trans,b3);\n"
" \n"
" result3.x += dot(a3_trans,b0);\n"
" result3.y += dot(a3_trans,b1);\n"
" result3.z += dot(a3_trans,b2);\n"
" result3.w += dot(a3_trans,b3);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx),result0);\n"
" if(4*height_blocks_idx+1 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+1),result1);\n"
" if(4*height_blocks_idx+2 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+2),result2);\n"
" if(4*height_blocks_idx+3 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+3),result3);\n"
"}\n"
;
const char* binary = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define PI 3.141592653589f\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void binary(__private int global_dim0,__private int global_dim1,\n"
" __read_only image2d_t input0,__read_only image2d_t input1,\n"
" __write_only image2d_t output,\n"
" __private const int4 shape,//[N,H,W,C4]\n"
" __private const int2 isFull,\n"
" __private const int activationType) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));//WC4,NH\n"
" \n"
" float4 in0,in1;\n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" if(isFull.x == 0) {\n"
" in0=convert_float4(RI_DATA(input0,SAMPLER,(int2)(0,0)));\n"
" in0=(float4)(in0.x,in0.x,in0.x,in0.x);\n"
" } else {\n"
" in0=convert_float4(RI_DATA(input0,SAMPLER,pos));\n"
" }\n"
" if(isFull.y == 0) {\n"
" in1=convert_float4(RI_DATA(input1,SAMPLER,(int2)(0,0)));\n"
" in1=(float4)(in1.x,in1.x,in1.x,in1.x);\n"
" } else {\n"
" in1=convert_float4(RI_DATA(input1,SAMPLER,pos));\n"
" }\n"
" \n"
" float4 out=OPERATOR;\n"
" \n"
" if(activationType == 1) {\n"
" out=fmax(out,(float4)0);\n"
" }\n"
" WI_DATA(output,pos,CONVERT_OUTPUT_I4(out));\n"
" }\n"
"}\n"
"__kernel void binary_prelu(__read_only image2d_t input0,__read_only image2d_t input1,__write_only image2d_t output,\n"
" int4 shape,int2 whInput1,int4 input1NHWCStep) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" int4 nhwc=(int4)(pos.y/shape.y,pos.y%shape.y,pos.x%shape.z,pos.x/shape.z);\n"
" if (nhwc.x<shape.x && nhwc.w<shape.w) {\n"
" int4 nhwc1=nhwc*input1NHWCStep;\n"
" int2 pos1=(int2)(nhwc1.w*whInput1.x+nhwc1.z,nhwc1.x*whInput1.y+nhwc1.y);\n"
" float4 in0=convert_float4(RI_DATA(input0,SAMPLER,pos));\n"
" float4 in1=convert_float4(RI_DATA(input1,SAMPLER,pos1));\n"
" OUTPUT_TYPE_I4 out=CONVERT_OUTPUT_I4(OPERATOR);\n"
" WI_DATA(output,pos,out);\n"
" }\n"
"}\n"
"__kernel void imageCopy(__read_only image2d_t input,__write_only image2d_t output) {\n"
" const int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" const int2 dim=get_image_dim(input);\n"
" if (pos.x >= dim.x && pos.y >= dim.y) {\n"
" return;\n"
" }\n"
" WI_DATA(output,pos,CONVERT_OUTPUT_I4(RI_DATA(input,SAMPLER,pos)));\n"
"}\n"
;
// OpenCL program source for buffer-based layout conversion kernels
// (tiled transposes, tile/pack between the 4-channel-packed layout and plain
// layouts, and a strided element-wise binary op). The whole string is
// compiled out when MNN_OPENCL_BUFFER_CLOSED is defined.
// INPUT_TYPE*, OUTPUT_TYPE*, CONVERT_OUTPUT4 and LOOP_BINARY_OPERATOR are not
// defined here; they are expected as build-time macros (TODO confirm).
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* loop_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define PI 3.141592653589f\n"
// Default work-group tile (WGS*) and per-thread tile (TS*) sizes; each one is
// wrapped in #ifndef so it can be overridden from the build options.
"#ifndef WGSW\n"
" #define WGSW 32 // work-group handle size W dimension\n"
"#endif\n"
"#ifndef WGSC\n"
" #define WGSC 32 // work-group handle size C dimension\n"
"#endif\n"
"#ifndef WGSH\n"
" #define WGSH 32 // work-group handle size H dimension\n"
"#endif\n"
"#ifndef TSW\n"
" #define TSW 8 // thread handle size W dimension\n"
"#endif\n"
"#ifndef TSC\n"
" #define TSC 8 // thread handle size C dimension\n"
"#endif\n"
"#ifndef TSH\n"
" #define TSH 8 // thread handle size H dimension\n"
"#endif\n"
// tile_trans_3d_buf: transposes a [C4 N H 1 4] buffer into [N H C 1] through
// a __local tile of WGSH x (WGSC/4) vec4 values. Global dim 2 is the batch
// index. The load indexing (j*WGSC/TSC+lidc, i*WGSH/TSH+lidh) implies a local
// size of (WGSC/TSC, WGSH/TSH, 1). Each work-item loads TSH * TSC/4 vec4s
// (out-of-range source elements become 0), hits the barrier, then stores
// TSH * TSC/4 vec4s. Stores are NOT bounds-checked, so they rely on
// heightPad/channelPad covering the full tile (presumably padded to
// WGSH/WGSC multiples; TODO confirm).
"// [C4 N H 1 4] -> [N H C 1]\n"
"__kernel void tile_trans_3d_buf(__global INPUT_TYPE* input,\n"
" __global OUTPUT_TYPE* output,\n"
" __private const int widthPad,\n"
" __private const int heightPad,\n"
" __private const int channelPad,\n"
" __private const int batch,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel\n"
") {\n"
" int b=get_global_id(2);\n"
" \n"
" const int lidc=get_local_id(0);\n"
" const int lidh=get_local_id(1);\n"
" // group id\n"
" const int c=get_group_id(0)*WGSC;\n"
" const int h=get_group_id(1)*WGSH;\n"
" int jc=lidc;\n"
" int ih=lidh;\n"
" \n"
" __local INPUT_TYPE4 localData[WGSH][WGSC/4];//h64c64\n"
" \n"
" #pragma unroll\n"
" for(int i=0; i<TSH; i++) {\n"
" #pragma unroll\n"
" for(int j=0; j<TSC/4; j++) {\n"
" int offset_h=i*WGSH/TSH+ih;\n"
" int offset_c=j*WGSC/TSC+jc ;\n"
" // [TSH,WGSH/TSH] [TSC/4,WGSC/TSC,4]\n"
" localData[offset_h][offset_c]=(h+offset_h >= height || c+4*offset_c >= channel) ? (INPUT_TYPE4)0 : vload4(0,input+((b+(c/4+offset_c)*batch)*height+(h+offset_h))*4);\n"
" }\n"
" }\n"
" \n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" \n"
" // C offset: [WGSC/TSC,TSC/4]\n"
" // H offset: [WGSH/TSH,TSH]\n"
" int oc_base=jc*TSC/4;\n"
" int oh_base=ih*TSH;\n"
" //#pragma unroll\n"
" for(int i=0; i<TSH; i++) {\n"
" int oh=oh_base+i;\n"
" //#pragma unroll\n"
" for(int j=0; j<TSC/4; j++) {\n"
" int oc=oc_base+j;\n"
" \n"
" OUTPUT_TYPE4 value=CONVERT_OUTPUT4(localData[oh][oc]);\n"
" vstore4(value,0,output+((b*heightPad+h+oh)*channelPad+c+4*oc));\n"
" }\n"
" }\n"
"}\n"
// tile_trans_4d_buf: transposes a [C4 N H W 4] buffer into [N C W H] through
// a __local tile of WGSH x WGSW vec4 values. Global dim 2 enumerates
// (c4, batch) pairs as c4*batch + b. After the barrier each work-item gathers
// 16 consecutive H rows of one W column into a vec16 per channel lane and
// stores it contiguously along H. Lane 0 is always stored; lanes 1..3 are
// skipped past `channel`. The `>> 4` / `& (16/TSH-1)` index math presumably
// requires TSH to divide 16, and the vec16 stores rely on heightPad/widthPad
// covering the tile (TODO confirm).
"// [C4 N H W 4] -> [N C W H]\n"
"__kernel void tile_trans_4d_buf(__global INPUT_TYPE* input,\n"
" __global OUTPUT_TYPE* output,\n"
" __private const int widthPad,\n"
" __private const int heightPad,\n"
" __private const int channelPad,\n"
" __private const int batch,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel\n"
") {\n"
" int bc=get_global_id(2);\n"
" int b=bc % batch;\n"
" int c4=bc/batch;\n"
" int c=c4 << 2;\n"
" \n"
" const int lidw=get_local_id(0);\n"
" const int lidh=get_local_id(1);\n"
" // group id\n"
" const int w=get_group_id(0)*WGSW;\n"
" const int h=get_group_id(1)*WGSH;\n"
" int jw=lidw;\n"
" int ih=lidh;\n"
" \n"
" __local INPUT_TYPE4 localData[WGSH][WGSW];//w32h32c4\n"
" \n"
" #pragma unroll\n"
" for(int i=0; i<TSH; i++) {\n"
" #pragma unroll\n"
" for(int j=0; j<TSW; j++) {\n"
" int offset_h=h+ih+i*WGSH/TSH;\n"
" int offset_w=w+jw+j*WGSW/TSW;\n"
" localData[ih+i*WGSH/TSH][jw+j*WGSW/TSW]=(offset_h >= height || offset_w >= width) ? (INPUT_TYPE4)0 : vload4(0,input+(((b+c4*batch)*height+offset_h)*width+offset_w)*4);\n"
" }\n"
" }\n"
" \n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" \n"
" // c4w32h32\n"
" int oh=ih*TSH >> 4;\n"
" int mh=ih & (16/TSH-1);\n"
" // TSW offset: [TSH/4,TSW/4,16/TSH]\n"
" int ow_base=jw*TSW;\n"
" int oh_offset=oh << 4;\n"
" //#pragma unroll\n"
" for(int i=0; i<TSH/4; i++) {\n"
" //#pragma unroll\n"
" for(int j=0; j<TSW/4; j++) {\n"
" \n"
" // c4\n"
" OUTPUT_TYPE16 value;\n"
" int ow=ow_base+(((i*TSW/4)+j)*(16/TSH)+mh);\n"
" \n"
" value.s0=localData[0+oh_offset][ow].s0;\n"
" value.s1=localData[1+oh_offset][ow].s0;\n"
" value.s2=localData[2+oh_offset][ow].s0;\n"
" value.s3=localData[3+oh_offset][ow].s0;\n"
" value.s4=localData[4+oh_offset][ow].s0;\n"
" value.s5=localData[5+oh_offset][ow].s0;\n"
" value.s6=localData[6+oh_offset][ow].s0;\n"
" value.s7=localData[7+oh_offset][ow].s0;\n"
" value.s8=localData[8+oh_offset][ow].s0;\n"
" value.s9=localData[9+oh_offset][ow].s0;\n"
" value.sa=localData[10+oh_offset][ow].s0;\n"
" value.sb=localData[11+oh_offset][ow].s0;\n"
" value.sc=localData[12+oh_offset][ow].s0;\n"
" value.sd=localData[13+oh_offset][ow].s0;\n"
" value.se=localData[14+oh_offset][ow].s0;\n"
" value.sf=localData[15+oh_offset][ow].s0;\n"
" vstore16(value,0,output+(((b*channelPad+c+0)*widthPad+w+ow)*heightPad+h+oh_offset));\n"
" \n"
" if(c+1<channel) {\n"
" value.s0=localData[0+oh_offset][ow].s1;\n"
" value.s1=localData[1+oh_offset][ow].s1;\n"
" value.s2=localData[2+oh_offset][ow].s1;\n"
" value.s3=localData[3+oh_offset][ow].s1;\n"
" value.s4=localData[4+oh_offset][ow].s1;\n"
" value.s5=localData[5+oh_offset][ow].s1;\n"
" value.s6=localData[6+oh_offset][ow].s1;\n"
" value.s7=localData[7+oh_offset][ow].s1;\n"
" value.s8=localData[8+oh_offset][ow].s1;\n"
" value.s9=localData[9+oh_offset][ow].s1;\n"
" value.sa=localData[10+oh_offset][ow].s1;\n"
" value.sb=localData[11+oh_offset][ow].s1;\n"
" value.sc=localData[12+oh_offset][ow].s1;\n"
" value.sd=localData[13+oh_offset][ow].s1;\n"
" value.se=localData[14+oh_offset][ow].s1;\n"
" value.sf=localData[15+oh_offset][ow].s1;\n"
" vstore16(value,0,output+(((b*channelPad+c+1)*widthPad+w+ow)*heightPad+h+oh_offset));\n"
" }\n"
" \n"
" if(c+2<channel) {\n"
" value.s0=localData[0+oh_offset][ow].s2;\n"
" value.s1=localData[1+oh_offset][ow].s2;\n"
" value.s2=localData[2+oh_offset][ow].s2;\n"
" value.s3=localData[3+oh_offset][ow].s2;\n"
" value.s4=localData[4+oh_offset][ow].s2;\n"
" value.s5=localData[5+oh_offset][ow].s2;\n"
" value.s6=localData[6+oh_offset][ow].s2;\n"
" value.s7=localData[7+oh_offset][ow].s2;\n"
" value.s8=localData[8+oh_offset][ow].s2;\n"
" value.s9=localData[9+oh_offset][ow].s2;\n"
" value.sa=localData[10+oh_offset][ow].s2;\n"
" value.sb=localData[11+oh_offset][ow].s2;\n"
" value.sc=localData[12+oh_offset][ow].s2;\n"
" value.sd=localData[13+oh_offset][ow].s2;\n"
" value.se=localData[14+oh_offset][ow].s2;\n"
" value.sf=localData[15+oh_offset][ow].s2;\n"
" vstore16(value,0,output+(((b*channelPad+c+2)*widthPad+w+ow)*heightPad+h+oh_offset));\n"
" }\n"
" \n"
" if(c+3<channel) {\n"
" value.s0=localData[0+oh_offset][ow].s3;\n"
" value.s1=localData[1+oh_offset][ow].s3;\n"
" value.s2=localData[2+oh_offset][ow].s3;\n"
" value.s3=localData[3+oh_offset][ow].s3;\n"
" value.s4=localData[4+oh_offset][ow].s3;\n"
" value.s5=localData[5+oh_offset][ow].s3;\n"
" value.s6=localData[6+oh_offset][ow].s3;\n"
" value.s7=localData[7+oh_offset][ow].s3;\n"
" value.s8=localData[8+oh_offset][ow].s3;\n"
" value.s9=localData[9+oh_offset][ow].s3;\n"
" value.sa=localData[10+oh_offset][ow].s3;\n"
" value.sb=localData[11+oh_offset][ow].s3;\n"
" value.sc=localData[12+oh_offset][ow].s3;\n"
" value.sd=localData[13+oh_offset][ow].s3;\n"
" value.se=localData[14+oh_offset][ow].s3;\n"
" value.sf=localData[15+oh_offset][ow].s3;\n"
" vstore16(value,0,output+(((b*channelPad+c+3)*widthPad+w+ow)*heightPad+h+oh_offset));\n"
" }\n"
" }\n"
" }\n"
"}\n"
// tile_buf: unpacks a 4-channel-packed source buffer whose pitches are
// x=4, y=4*width, b=y*height, c4=b*batch (i.e. [C4,N,H,W,4]) into a plain
// destination layout selected by MNN_NHWC / DIMENSION_3 / DIMENSION_4 /
// TRANSPOSE. Global dims: (w, h, c4*batch + b). Positions outside
// (width, height, channel) still write zeros, presumably to fill destination
// padding. Up to 4 channels are written per item, stopping at `channel`.
"__kernel void tile_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input,__global OUTPUT_TYPE* output,\n"
" __private const int widthPad,\n"
" __private const int heightPad,\n"
" __private const int channelPad,\n"
" __private const int batch,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int b=pos.z % batch;\n"
" const int w=pos.x;\n"
" const int h=pos.y;\n"
" const int c_4=pos.z/batch;\n"
" \n"
" const int c=c_4 << 2;\n"
" const int x_src_pitch=4;\n"
" const int y_src_pitch=x_src_pitch*width;\n"
" const int b_src_pitch=y_src_pitch*height;\n"
" const int c_src_pitch=b_src_pitch*batch;\n"
" \n"
" bool outBound=(w >= width || h >= height || c >= channel);\n"
"#ifdef MNN_NHWC\n"
" #if defined(DIMENSION_3) && defined(TRANSPOSE)\n"
" // [N,W,H,1]\n"
" const int c_dst_pitch=1;\n"
" const int y_dst_pitch=c_dst_pitch*channelPad;\n"
" const int x_dst_pitch=y_dst_pitch*heightPad;\n"
" const int b_dst_pitch=x_dst_pitch*widthPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #elif defined(DIMENSION_4) && defined(TRANSPOSE)\n"
" // [N,H,C,W]\n"
" const int x_dst_pitch=1;\n"
" const int c_dst_pitch=x_dst_pitch*widthPad;\n"
" const int y_dst_pitch=c_dst_pitch*channelPad;\n"
" const int b_dst_pitch=y_dst_pitch*heightPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #elif defined(DIMENSION_3)\n"
" // [N,H,W,1]\n"
" const int c_dst_pitch=1;\n"
" const int x_dst_pitch=c_dst_pitch*channelPad;\n"
" const int y_dst_pitch=x_dst_pitch*widthPad;\n"
" const int b_dst_pitch=y_dst_pitch*heightPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #else\n"
" // [N,H,W,C]\n"
" const int c_dst_pitch=1;\n"
" const int x_dst_pitch=c_dst_pitch*channelPad;\n"
" const int y_dst_pitch=x_dst_pitch*widthPad;\n"
" const int b_dst_pitch=y_dst_pitch*heightPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #endif\n"
"#else\n"
" #if defined(DIMENSION_3) && defined(TRANSPOSE)\n"
" // [N,H,C,1]\n"
" const int x_dst_pitch=1;\n"
" const int c_dst_pitch=x_dst_pitch*widthPad;\n"
" const int y_dst_pitch=c_dst_pitch*channelPad;\n"
" const int b_dst_pitch=y_dst_pitch*heightPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" \n"
" #elif defined(DIMENSION_4) && defined(TRANSPOSE)\n"
" // [N,C,W,H]\n"
" const int y_dst_pitch=1;\n"
" const int x_dst_pitch=y_dst_pitch*heightPad;\n"
" const int c_dst_pitch=x_dst_pitch*widthPad;\n"
" const int b_dst_pitch=c_dst_pitch*channelPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #elif defined(DIMENSION_3)\n"
" // [N,C,H,1]\n"
" const int x_dst_pitch=1;\n"
" const int y_dst_pitch=x_dst_pitch*widthPad;\n"
" const int c_dst_pitch=y_dst_pitch*heightPad;\n"
" const int b_dst_pitch=c_dst_pitch*channelPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #else\n"
" // [N,C,H,W]\n"
" const int x_dst_pitch=1;\n"
" const int y_dst_pitch=x_dst_pitch*widthPad;\n"
" const int c_dst_pitch=y_dst_pitch*heightPad;\n"
" const int b_dst_pitch=c_dst_pitch*channelPad;\n"
" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n"
" #endif\n"
"#endif\n"
" __global OUTPUT_TYPE* dst_ptr=output+b*b_dst_pitch+c*c_dst_pitch+h*y_dst_pitch+w*x_dst_pitch;\n"
" dst_ptr[0]=value.x;\n"
" if(c+1 >= channel)return;\n"
" dst_ptr[c_dst_pitch]=value.y;\n"
" if(c+2 >= channel)return;\n"
" dst_ptr[2*c_dst_pitch]=value.z;\n"
" if(c+3 >= channel)return;\n"
" dst_ptr[3*c_dst_pitch]=value.w;\n"
" }\n"
"}\n"
// pack_buf: inverse direction of tile_buf. It gathers up to 4 channels from a
// plain source layout (selected by the same macros) into one vec4, zero-filling
// lanes past `channel`. Items outside (width, height, channel) return
// without writing.
// NOTE(review): the destination pitches here are x=4, y=4*width,
// c4=y*height, b=c4*ceil(channel/4), i.e. [N,C4,H,W,4]. tile_buf's source and
// the tile_trans_* kernels above use [C4,N,H,W,4]. The two only agree when
// batch == 1. Confirm which layout the host expects.
// NOTE(review): in the MNN_NHWC TRANSPOSE+DIMENSION_3 branch,
// y_src_pitch == c_src_pitch (1). tile_buf's matching branch uses
// y pitch = channelPad, so the two agree only if channelPad == 1. Verify.
"__kernel void pack_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global INPUT_TYPE* input,__global OUTPUT_TYPE* output,\n"
" __private const int widthPad,\n"
" __private const int heightPad,\n"
" __private const int channelPad,\n"
" __private const int batch,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" \n"
" const int b=pos.z % batch;\n"
" const int w=pos.x;\n"
" const int h=pos.y;\n"
" const int c_4=pos.z/batch;\n"
" \n"
" const int c=c_4 << 2;\n"
" if(w >= width || h >= height || c >= channel) {\n"
" return;\n"
" }\n"
" const int x_dst_pitch=4;\n"
" const int y_dst_pitch=x_dst_pitch*width;\n"
" const int c_dst_pitch=y_dst_pitch*height;\n"
" const int b_dst_pitch=c_dst_pitch*((channel+3)/4);\n"
"#ifdef MNN_NHWC\n"
" #if defined(TRANSPOSE) && defined(DIMENSION_3)\n"
" // [N,W,H,1]\n"
" const int c_src_pitch=1;\n"
" const int y_src_pitch=c_src_pitch;\n"
" const int x_src_pitch=y_src_pitch*heightPad;\n"
" const int b_src_pitch=x_src_pitch*widthPad;\n"
" #elif defined(TRANSPOSE) && defined(DIMENSION_4)\n"
" // [N,H,C,W]\n"
" const int x_src_pitch=1;\n"
" const int c_src_pitch=x_src_pitch*widthPad;\n"
" const int y_src_pitch=c_src_pitch*channelPad;\n"
" const int b_src_pitch=y_src_pitch*heightPad;\n"
" #else\n"
" // [N,H,W,C]\n"
" const int c_src_pitch=1;\n"
" const int x_src_pitch=c_src_pitch*channelPad;\n"
" const int y_src_pitch=x_src_pitch*widthPad;\n"
" const int b_src_pitch=y_src_pitch*heightPad;\n"
" #endif\n"
"#else\n"
" #if defined(TRANSPOSE) && defined(DIMENSION_3)\n"
" // dst:[N,C,H,1] -> src:[N,H,C,1]\n"
" const int x_src_pitch=1;\n"
" const int c_src_pitch=x_src_pitch*widthPad;\n"
" const int y_src_pitch=c_src_pitch*channelPad;\n"
" const int b_src_pitch=y_src_pitch*heightPad;\n"
" #elif defined(TRANSPOSE) && defined(DIMENSION_4)\n"
" // dst:[N,C,H,W] -> src:[N,C,W,H]\n"
" const int y_src_pitch=1;\n"
" const int x_src_pitch=y_src_pitch*heightPad;\n"
" const int c_src_pitch=x_src_pitch*widthPad;\n"
" const int b_src_pitch=c_src_pitch*channelPad;\n"
" #else\n"
" // [N,C,H,W]\n"
" const int x_src_pitch=1;\n"
" const int y_src_pitch=x_src_pitch*widthPad;\n"
" const int c_src_pitch=y_src_pitch*heightPad;\n"
" const int b_src_pitch=c_src_pitch*channelPad;\n"
" #endif\n"
"#endif\n"
" __global INPUT_TYPE* src_ptr=input+b*b_src_pitch+c*c_src_pitch+h*y_src_pitch+w*x_src_pitch;\n"
" OUTPUT_TYPE4 value=(OUTPUT_TYPE4)0;\n"
" OUTPUT_TYPE *value_ptr=(OUTPUT_TYPE*)&value;\n"
" for(int i=0; i<4 && (i+c<channel); ++i){\n"
" value_ptr[i]=(OUTPUT_TYPE)src_ptr[i*c_src_pitch];\n"
" }\n"
" vstore4(value,0,output+b*b_dst_pitch+c_4*c_dst_pitch+h*y_dst_pitch+w*x_dst_pitch);\n"
" }\n"
"}\n"
// loop_binary_buf: only compiled when LOOP_BINARY_OPERATOR (an expression in
// in0 and in1) is defined. For every (x, y, z) in the global range it reads one
// element from each input using independent 3-level strides, evaluates the
// operator in float, and writes the result cast to OUTPUT_TYPE.
"#ifdef LOOP_BINARY_OPERATOR\n"
"__kernel void loop_binary_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n"
" __private const int input0Stride0,\n"
" __private const int input0Stride1,\n"
" __private const int input0Stride2,\n"
" __private const int input1Stride0,\n"
" __private const int input1Stride1,\n"
" __private const int input1Stride2,\n"
" __private const int outputStride0,\n"
" __private const int outputStride1,\n"
" __private const int outputStride2\n"
" ) {\n"
" \n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1);\n"
" const int z=get_global_id(2);\n"
" \n"
" if (x<global_dim0 && y<global_dim1 && z<global_dim2) {\n"
" \n"
" int inputIndex0=z*input0Stride0+y*input0Stride1+x*input0Stride2;\n"
" int inputIndex1=z*input1Stride0+y*input1Stride1+x*input1Stride2;\n"
" int outputIndex=z*outputStride0+y*outputStride1+x*outputStride2;\n"
" float in0=(float)input0[inputIndex0];\n"
" float in1=(float)input1[inputIndex1];\n"
" float out=LOOP_BINARY_OPERATOR;\n"
" output[outputIndex]=(OUTPUT_TYPE)out;\n"
" }\n"
"}\n"
"#endif\n"
;
#endif
// OpenCL program source for max ROI pooling on image2d tensors.
// One work-item per (channel block, output x, batch*out_height + output y),
// bounds-checked by DEAL_NON_UNIFORM_DIM3. Input pixels are addressed as
// x = channel_block*in_width + w, y = batch*in_height + h.
// RI_F, WI_F and FLOAT4 are not defined in this string; they are expected as
// build-time macros (TODO confirm).
const char* roi_pooling = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define MIN_VALUE -FLT_MAX\n"
"// Supported data type: half/float\n"
// Each ROI is 5 scalars: (batch_index, x1, y1, x2, y2). How those scalars sit
// in the `roi` image depends on the macro:
//   ROI_C1H1W5 - pixels x=0..4 of row roi_batch_idx, using each pixel's .x;
//   ROI_C5H1W1 - pixel (0,row).xyzw then pixel (1,row).x;
//   default    - pixels (0, roi_batch_idx*5 + k), using .x, for k = 0..4.
// Coordinates are scaled by spatial_scale and rounded. ROIs whose batch index
// is >= in_batch return WITHOUT writing their output pixel.
// Bins follow floor/ceil splitting clamped to [0, in_height] / [0, in_width].
// Empty bins write zero; otherwise the max over the bin is written.
// NOTE(review): `batch_idx` below is computed but never used (it duplicates
// in_height_offset).
"__kernel void roi_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__read_only image2d_t roi,\n"
" __private const int in_height,__private const int in_width,__private const int in_batch,\n"
" __private const int out_height,__private const int out_width,__private const float spatial_scale,\n"
" __write_only image2d_t output) {\n"
" const int out_channel_idx=get_global_id(0);\n"
" const int out_width_idx=get_global_id(1);\n"
" const int out_hb_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(out_channel_idx,out_width_idx,out_hb_idx);\n"
" const int roi_batch_idx=out_hb_idx/out_height;\n"
" const int out_height_idx=out_hb_idx % out_height;\n"
"#if defined ROI_C1H1W5\n"
" FLOAT4 roi_0=RI_F(roi,SAMPLER,(int2)(0,roi_batch_idx));\n"
" int input_batch=roi_0.x;\n"
" if(input_batch >= in_batch){\n"
" return;\n"
" }\n"
" FLOAT4 roi_1=RI_F(roi,SAMPLER,(int2)(1,roi_batch_idx));\n"
" FLOAT4 roi_2=RI_F(roi,SAMPLER,(int2)(2,roi_batch_idx));\n"
" FLOAT4 roi_3=RI_F(roi,SAMPLER,(int2)(3,roi_batch_idx));\n"
" FLOAT4 roi_4=RI_F(roi,SAMPLER,(int2)(4,roi_batch_idx));\n"
" int x1=round(roi_1.x*spatial_scale);\n"
" int y1=round(roi_2.x*spatial_scale);\n"
" int x2=round(roi_3.x*spatial_scale);\n"
" int y2=round(roi_4.x*spatial_scale);\n"
"#elif defined ROI_C5H1W1\n"
" FLOAT4 roi_0=RI_F(roi,SAMPLER,(int2)(0,roi_batch_idx));\n"
" int input_batch=roi_0.x;\n"
" if(input_batch >= in_batch){\n"
" return;\n"
" }\n"
" FLOAT4 roi_1=RI_F(roi,SAMPLER,(int2)(1,roi_batch_idx));\n"
" int x1=round(roi_0.y*spatial_scale);\n"
" int y1=round(roi_0.z*spatial_scale);\n"
" int x2=round(roi_0.w*spatial_scale);\n"
" int y2=round(roi_1.x*spatial_scale);\n"
"#else\n"
" const int roi_batch_offset=roi_batch_idx*5;\n"
" FLOAT4 roi_0=RI_F(roi,SAMPLER,(int2)(0,roi_batch_offset));\n"
" int input_batch=roi_0.x;\n"
" if(input_batch >= in_batch){\n"
" return;\n"
" }\n"
" FLOAT4 roi_1=RI_F(roi,SAMPLER,(int2)(0,roi_batch_offset+1));\n"
" FLOAT4 roi_2=RI_F(roi,SAMPLER,(int2)(0,roi_batch_offset+2));\n"
" FLOAT4 roi_3=RI_F(roi,SAMPLER,(int2)(0,roi_batch_offset+3));\n"
" FLOAT4 roi_4=RI_F(roi,SAMPLER,(int2)(0,roi_batch_offset+4));\n"
" int x1=round(roi_1.x*spatial_scale);\n"
" int y1=round(roi_2.x*spatial_scale);\n"
" int x2=round(roi_3.x*spatial_scale);\n"
" int y2=round(roi_4.x*spatial_scale);\n"
"#endif\n"
" int roiW=max(x2-x1+1,1);\n"
" int roiH=max(y2-y1+1,1);\n"
" float binSizeW=(float)roiW/(float)out_width;\n"
" float binSizeH=(float)roiH/(float)out_height;\n"
" int hStart=min(max(y1+(int)floor(out_height_idx*binSizeH),0),in_height);\n"
" int hEnd=min(max(y1+(int)ceil((out_height_idx+1)*binSizeH),0),in_height);\n"
" int hLen=hEnd-hStart;\n"
" int wStart=min(max(x1+(int)floor(out_width_idx*binSizeW),0),in_width);\n"
" int wEnd=min(max(x1+(int)ceil((out_width_idx+1)*binSizeW),0),in_width);\n"
" int wLen=wEnd-wStart;\n"
" const int pos=mad24(out_channel_idx,out_width,out_width_idx);\n"
" const FLOAT4 zero_vec=(FLOAT4)(0);\n"
" if (wLen <= 0 || hLen <= 0) {\n"
" WI_F(output,(int2)(pos,out_hb_idx),zero_vec);\n"
" return;\n"
" }\n"
" FLOAT4 res=(FLOAT4)(MIN_VALUE);\n"
" const int in_height_start=hStart;\n"
" const int in_width_start=wStart;\n"
" const int in_channel_offset=mul24(out_channel_idx,in_width);\n"
" const int in_height_offset=mul24(input_batch,in_height);\n"
" const int batch_idx=mul24(input_batch,in_height);\n"
" for (int height=0; height<hLen; ++height) {\n"
" int in_height_idx=in_height_start+height;\n"
" for (int width=0; width<wLen; ++width) {\n"
" int in_width_idx=in_width_start+width;\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(in_channel_offset+in_width_idx,in_height_offset+in_height_idx));\n"
" res=fmax(res,in);\n"
" }\n"
" }\n"
" WI_F(output,(int2)(pos,out_hb_idx),res);\n"
"}\n"
;
// OpenCL program source for depthwise 2-D convolution on image2d tensors.
// Each work-item produces 4 horizontally adjacent output pixels of one
// channel block. Global dim 0 = channel_block * ceil(outW/4) + width_block,
// and global dim 1 = batch*outH + output row.
// Shape vectors are (.x = height, .y = width): inputShape, outputShape,
// filterShape, paddingShape. Input pixels are addressed as
// x = channel_block*inW + w, y = batch*inH + h. Filter pixels are addressed as
// x = kh*kW + kw, y = channel_block.
// RI_F, WI_F and FLOAT4 are not defined in this string; they are expected as
// build-time macros (TODO confirm).
const char* depthwise_conv2d = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// READ_INPUT_IMAGE maps an out-of-range width (or a row already mapped to -1)
// to coordinate -1. With CLK_ADDRESS_CLAMP the read then returns the border
// color, which is zero for RGBA images, so this implements zero padding.
// NOTE(review): CALCULATE_OUTPUT is defined but not used by either kernel here.
"#define READ_INPUT_IMAGE(i, base) "" int inOffset##i = inWidthOffset##i + base; "" inOffset##i = "" select(inCurIdx + inOffset##i, -1, (inOffset##i < 0 || inOffset##i >= inputShape.y)); "" inValue##i=RI_F(input,SAMPLER,(int2)(inOffset##i,inHeightIdx));\n"
"#define CALCULATE_OUTPUT(i) "" outValue##i = mad(inValue##i.x, weights0, outValue##i); "" outValue##i = mad(inValue##i.y, weights1, outValue##i); "" outValue##i = mad(inValue##i.z, weights2, outValue##i); "" outValue##i=mad(inValue##i.w,weights3,outValue##i);\n"
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
// depthwise_conv2d_s1: variant with unit stride and unit dilation (rows and
// columns advance by 1 per kh/kw). It keeps a 4-pixel sliding window so each
// kw step reads one new input pixel instead of four. Optional bias (unless
// NO_BIAS), fused RELU/RELU6, and only the in-range tail of the last 4-wide
// block is written.
// NOTE(review): inChannelBlocks is accepted but not used.
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,__read_only image2d_t filter,\n"
" #ifndef NO_BIAS\n"
" __read_only image2d_t bias,\n"
" #endif\n"
" __write_only image2d_t output,\n"
" __private const int2 inputShape,\n"
" __private const int inChannelBlocks,\n"
" __private const int2 outputShape,\n"
" __private const int2 filterShape,\n"
" __private const int2 paddingShape) {\n"
" const int outChannelWidthIdx=get_global_id(0);\n"
" const int outHeightBlockIdx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(outChannelWidthIdx,outHeightBlockIdx);\n"
" int ow4=(outputShape.y+3)/4;\n"
" const int outChannelBlockIdx=outChannelWidthIdx/ow4;\n"
" const int outWidthBlockidx=outChannelWidthIdx % ow4;\n"
" const int inChannelBlockIdx=outChannelBlockIdx;\n"
" #ifndef NO_BIAS\n"
" FLOAT4 outValue0=RI_F(bias,SAMPLER,(int2)(outChannelBlockIdx,0));\n"
" #else\n"
" FLOAT4 outValue0=(FLOAT4)(0.0f);\n"
" #endif\n"
" FLOAT4 outValue1=outValue0;\n"
" FLOAT4 outValue2=outValue0;\n"
" FLOAT4 outValue3=outValue0;\n"
" const int outWidthBlockidx4=outWidthBlockidx << 2;\n"
" const int inWidthOffset0=outWidthBlockidx4-paddingShape.y;\n"
" const int inWidthOffset1=inWidthOffset0+1;\n"
" const int inWidthOffset2=inWidthOffset0+2;\n"
" const int inWidthOffset3=inWidthOffset0+3;\n"
" int heightIdx=outHeightBlockIdx % outputShape.x-paddingShape.x;\n"
" const int outBatchIdx=mul24((outHeightBlockIdx/outputShape.x),inputShape.x);\n"
" const int inCurIdx=mul24(inChannelBlockIdx,inputShape.y);\n"
" const int inWidthIdx0=select(inCurIdx+inWidthOffset0,-1,(inWidthOffset0<0 || inWidthOffset0 >= inputShape.y));\n"
" const int inWidthIdx1=select(inCurIdx+inWidthOffset1,-1,(inWidthOffset1<0 || inWidthOffset1 >= inputShape.y));\n"
" const int inWidthIdx2=select(inCurIdx+inWidthOffset2,-1,(inWidthOffset2<0 || inWidthOffset2 >= inputShape.y));\n"
" FLOAT4 inValue0,inValue1,inValue2,inValue3;\n"
" for (int kh=0; kh<filterShape.x; kh++) {\n"
" int inHeightIdx=select(heightIdx+outBatchIdx,-1,(heightIdx<0 || heightIdx >= inputShape.x));\n"
" heightIdx++;\n"
" inValue1=RI_F(input,SAMPLER,(int2)(inWidthIdx0,inHeightIdx));\n"
" inValue2=RI_F(input,SAMPLER,(int2)(inWidthIdx1,inHeightIdx));\n"
" inValue3=RI_F(input,SAMPLER,(int2)(inWidthIdx2,inHeightIdx));\n"
" for (int kw=0; kw<filterShape.y; kw++) {\n"
" int filterIdx=mad24(kh,filterShape.y,kw);\n"
" inValue0=inValue1;\n"
" inValue1=inValue2;\n"
" inValue2=inValue3;\n"
" int inWidthIdx=inWidthOffset3+kw;\n"
" inWidthIdx=select(inCurIdx+inWidthIdx,-1,(inWidthIdx<0 || inWidthIdx >= inputShape.y));\n"
" inValue3=RI_F(input,SAMPLER,(int2)(inWidthIdx,inHeightIdx));\n"
" FLOAT4 weights=RI_F(filter,SAMPLER,(int2)(filterIdx,inChannelBlockIdx));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" outValue2=mad(inValue2,weights,outValue2);\n"
" outValue3=mad(inValue3,weights,outValue3);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(FLOAT4)0);\n"
" outValue1=fmax(outValue1,(FLOAT4)0);\n"
" outValue2=fmax(outValue2,(FLOAT4)0);\n"
" outValue3=fmax(outValue3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(FLOAT4)0,(FLOAT4)6);\n"
" outValue1=clamp(outValue1,(FLOAT4)0,(FLOAT4)6);\n"
" outValue2=clamp(outValue2,(FLOAT4)0,(FLOAT4)6);\n"
" outValue3=clamp(outValue3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int remain=outputShape.y-outWidthBlockidx4;\n"
" int outWidthIdx=mul24(outChannelBlockIdx,outputShape.y)+outWidthBlockidx4;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightBlockIdx),outValue0);\n"
" WI_F(output,(int2)(outWidthIdx+1,outHeightBlockIdx),outValue1);\n"
" WI_F(output,(int2)(outWidthIdx+2,outHeightBlockIdx),outValue2);\n"
" WI_F(output,(int2)(outWidthIdx+3,outHeightBlockIdx),outValue3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightBlockIdx),outValue0);\n"
" WI_F(output,(int2)(outWidthIdx+1,outHeightBlockIdx),outValue1);\n"
" WI_F(output,(int2)(outWidthIdx+2,outHeightBlockIdx),outValue2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightBlockIdx),outValue0);\n"
" WI_F(output,(int2)(outWidthIdx+1,outHeightBlockIdx),outValue1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightBlockIdx),outValue0);\n"
" }\n"
"}\n"
// depthwise_conv2d: general variant supporting strideShape and dilationShape
// (.x = height, .y = width). The four outputs sit strideShape.y input columns
// apart, and each (kh, kw) tap re-reads four input pixels via READ_INPUT_IMAGE.
// Bias, activation and tail handling match depthwise_conv2d_s1.
// NOTE(review): inChannelBlocks is accepted but not used.
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,__read_only image2d_t filter,\n"
" #ifndef NO_BIAS\n"
" __read_only image2d_t bias,\n"
" #endif\n"
" __write_only image2d_t output,\n"
" __private const int2 inputShape,\n"
" __private const int inChannelBlocks,__private const int2 outputShape,\n"
" __private const int2 filterShape,\n"
" __private const int2 paddingShape,\n"
" __private const int2 dilationShape,\n"
" __private const int2 strideShape) {\n"
" const int outChannelWidthIdx=get_global_id(0);\n"
" const int outHeightIdx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(outChannelWidthIdx,outHeightIdx);\n"
" int ow4=(outputShape.y+3)/4;\n"
" const int outChannelBlockIdx=outChannelWidthIdx/ow4;\n"
" const int outWidthBlockidx=outChannelWidthIdx % ow4;\n"
" const int inChannelBlockIdx=outChannelBlockIdx;\n"
" #ifndef NO_BIAS\n"
" FLOAT4 outValue0=RI_F(bias,SAMPLER,(int2)(outChannelBlockIdx,0));\n"
" #else\n"
" FLOAT4 outValue0=(FLOAT4)(0.0f);\n"
" #endif\n"
" FLOAT4 outValue1=outValue0;\n"
" FLOAT4 outValue2=outValue0;\n"
" FLOAT4 outValue3=outValue0;\n"
" const int inWidthOffset0=mad24(outWidthBlockidx,strideShape.y << 2,-paddingShape.y);\n"
" const int inWidthOffset1=inWidthOffset0+strideShape.y;\n"
" const int inWidthOffset2=inWidthOffset1+strideShape.y;\n"
" const int inWidthOffset3=inWidthOffset2+strideShape.y;\n"
" int heightIdx=mad24(outHeightIdx % outputShape.x,strideShape.x,-paddingShape.x);\n"
" const int outBatchIdx=mul24((outHeightIdx/outputShape.x),inputShape.x);\n"
" const int inCurIdx=mul24(inChannelBlockIdx,inputShape.y);\n"
" for (int kh=0; kh<filterShape.x; kh++) {\n"
" int inHeightIdx=select(heightIdx+outBatchIdx,-1,(heightIdx<0 || heightIdx >= inputShape.x));\n"
" heightIdx += dilationShape.x;\n"
" for (int kw=0; kw<filterShape.y; kw++) {\n"
" int filterIdx=mad24(kh,filterShape.y,kw);\n"
" FLOAT4 inValue0,inValue1,inValue2,inValue3;\n"
" int inWidthIdx=mul24(kw,dilationShape.y);\n"
" READ_INPUT_IMAGE(0,inWidthIdx);\n"
" READ_INPUT_IMAGE(1,inWidthIdx);\n"
" READ_INPUT_IMAGE(2,inWidthIdx);\n"
" READ_INPUT_IMAGE(3,inWidthIdx);\n"
" FLOAT4 weights=RI_F(filter,SAMPLER,(int2)(filterIdx,inChannelBlockIdx));\n"
" outValue0=mad(inValue0,weights,outValue0);\n"
" outValue1=mad(inValue1,weights,outValue1);\n"
" outValue2=mad(inValue2,weights,outValue2);\n"
" outValue3=mad(inValue3,weights,outValue3);\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" outValue0=fmax(outValue0,(FLOAT4)0);\n"
" outValue1=fmax(outValue1,(FLOAT4)0);\n"
" outValue2=fmax(outValue2,(FLOAT4)0);\n"
" outValue3=fmax(outValue3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" outValue0=clamp(outValue0,(FLOAT4)0,(FLOAT4)6);\n"
" outValue1=clamp(outValue1,(FLOAT4)0,(FLOAT4)6);\n"
" outValue2=clamp(outValue2,(FLOAT4)0,(FLOAT4)6);\n"
" outValue3=clamp(outValue3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int outWidthBlockidx4=outWidthBlockidx << 2;\n"
" const int remain=outputShape.y-outWidthBlockidx4;\n"
" int outWidthIdx=mul24(outChannelBlockIdx,outputShape.y)+outWidthBlockidx4;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightIdx),outValue0);\n"
" WI_F(output,(int2)(outWidthIdx+1,outHeightIdx),outValue1);\n"
" WI_F(output,(int2)(outWidthIdx+2,outHeightIdx),outValue2);\n"
" WI_F(output,(int2)(outWidthIdx+3,outHeightIdx),outValue3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightIdx),outValue0);\n"
" WI_F(output,(int2)(outWidthIdx+1,outHeightIdx),outValue1);\n"
" WI_F(output,(int2)(outWidthIdx+2,outHeightIdx),outValue2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightIdx),outValue0);\n"
" WI_F(output,(int2)(outWidthIdx+1,outHeightIdx),outValue1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(outWidthIdx,outHeightIdx),outValue0);\n"
" }\n"
"}\n"
;
// OpenCL C source for layer normalization over image2d_t tensors whose channels are
// packed four per texel (channel block c, spatial w) -> image x = c*width + w and
// (batch b, row h) -> image y = b*height + h, as used by the RI_F/WI_F coordinates below.
// When RMSNORM is defined the mean is fixed at 0, i.e. only the root-mean-square is divided out.
//   layernorm_w   : normalizes along width for each (channel block, h, b) row.
//   layernorm_hw  : normalizes over height*width for each (channel block, b).
//   layernorm_chw : normalizes over channel*height*width for each batch b; the last channel
//                   block only contributes its first `channel_remain` lanes to the statistics.
// Build-time requirements visible from the code:
//   - LOCAL_SIZE must equal the work-group size along dim 0: it sizes the __local `sum` array
//     (indexed by get_local_id(0)) and is the stride of the per-work-item accumulation loops.
//   - LOCAL_SIZE must be a power of two; the halving tree reduction drops slots otherwise.
//   - FLOAT, RI_F, WI_F and CONVERT_FLOAT4 are not defined here (presumably injected as build
//     options by the host - TODO confirm).
// NOTE(review): every barrier sits inside the `pos < global_dim*` bounds check, so the host must
// launch whole work-groups that are entirely in range; a partially out-of-range group would hit
// a divergent barrier (undefined behavior).
// Each kernel reuses `sum` for the variance reduction, so a barrier is required after all
// work-items have read the mean from sum[0] and before any of them overwrites sum[lid].
const char* layernorm =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void layernorm_w(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
"#ifdef GAMMA_BETA\n"
" __global const FLOAT *gamma,\n"
" __global const FLOAT *beta,\n"
"#endif\n"
" __private float epsilon){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" float4 local sum[LOCAL_SIZE];\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int h=pos.y % height;\n"
" const int c=pos.y/height;\n"
" const int b=pos.z;\n"
" const int lid=get_local_id(0);\n"
" const int bh_offset=mad24(b,height,h);\n"
" float4 in_sum=0;\n"
"#ifdef RMSNORM\n"
" float4 mean=0;\n"
"#else\n"
" for(int i=lid; i<width; i+=LOCAL_SIZE){\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+i,bh_offset)));\n"
" in_sum += in;\n"
" }\n"
" sum[lid]=in_sum;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=sum[0]/(float4)width;\n"
// Every work-item must finish reading sum[0] before sum[] is overwritten below.
" barrier(CLK_LOCAL_MEM_FENCE);\n"
"#endif\n"
" in_sum=0;\n"
" for(int i=lid; i<width; i+=LOCAL_SIZE){\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+i,bh_offset)));\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" sum[lid]=in_sum;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float4 square_sum=sum[0]/(float4)width;\n"
" float4 value=(float4)1.0f/(float4)sqrt(square_sum+(float4)epsilon);\n"
" for(int i=lid; i<width; i+=LOCAL_SIZE){\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+i,bh_offset)));\n"
"#ifdef GAMMA_BETA\n"
" float4 out=(in-mean)*value*(float4)gamma[i]+(float4)beta[i];\n"
"#else\n"
" float4 out=(in-mean)*value;\n"
"#endif\n"
" WI_F(output,(int2)(c*width+i,bh_offset),CONVERT_FLOAT4(out));\n"
" }\n"
" }\n"
"}\n"
"__kernel void layernorm_hw(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
"#ifdef GAMMA_BETA\n"
" __global const FLOAT *gamma,\n"
" __global const FLOAT *beta,\n"
"#endif\n"
" __private float epsilon){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" float4 local sum[LOCAL_SIZE];\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int c=pos.y;\n"
" const int b=pos.z;\n"
" const int height_width=height*width;\n"
" const int lid=get_local_id(0);\n"
" float4 in_sum=0;\n"
"#ifdef RMSNORM\n"
" float4 mean=0;\n"
"#else\n"
" for(int i=lid; i<height_width; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+w,b*height+h)));\n"
" in_sum += in;\n"
" }\n"
" sum[lid]=in_sum;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=sum[0]/(float4)height_width;\n"
// Every work-item must finish reading sum[0] before sum[] is overwritten below.
" barrier(CLK_LOCAL_MEM_FENCE);\n"
"#endif\n"
" in_sum=0;\n"
" for(int i=lid; i<height_width; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+w,b*height+h)));\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" sum[lid]=in_sum;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float4 square_sum=sum[0]/(float4)height_width;\n"
" float4 value=(float4)1.0f/(float4)sqrt(square_sum+(float4)epsilon);\n"
" for(int i=lid; i<height_width; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+w,b*height+h)));\n"
"#ifdef GAMMA_BETA\n"
" float4 out=(in-mean)*value*(float4)gamma[i]+(float4)beta[i];\n"
"#else\n"
" float4 out=(in-mean)*value;\n"
"#endif\n"
" WI_F(output,(int2)(c*width+w,b*height+h),CONVERT_FLOAT4(out));\n"
" }\n"
" }\n"
"}\n"
"__kernel void layernorm_chw(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int width,\n"
" __private const int height,\n"
" __private const int channel,\n"
"#ifdef GAMMA_BETA\n"
" __global const FLOAT *gamma,\n"
" __global const FLOAT *beta,\n"
"#endif\n"
" __private float epsilon){\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" float local sum[LOCAL_SIZE];\n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" const int b=pos.z;\n"
" const int sum_size=width*height*channel;\n"
" const int reduce_size=width*height;\n"
" const int lid=get_local_id(0);\n"
" const int channel4=(channel+3)/4;\n"
" const int channel_remain=channel-(channel4-1)*4;\n"
" \n"
" float4 in_sum=0;\n"
" float4 in_sum_left=0;\n"
" float *in_sum_left_ptr=(float*)(&in_sum_left);\n"
"#ifdef RMSNORM\n"
" float4 mean=0;\n"
"#else\n"
" for(int c=0; c<channel4-1; ++c){\n"
" for(int i=lid; i<reduce_size; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+w,b*height+h)));\n"
" in_sum += in;\n"
" }\n"
" }\n"
" for(int i=lid; i<reduce_size; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)((channel4-1)*width+w,b*height+h)));\n"
" in_sum_left += in;\n"
" }\n"
" in_sum.x=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" for(int i=1; i<channel_remain; ++i){\n"
" in_sum_left_ptr[0] += in_sum_left_ptr[i];\n"
" }\n"
" sum[lid]=in_sum.x+in_sum_left.x;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" \n"
" float4 mean=sum[0]/(float4)sum_size;\n"
// Every work-item must finish reading sum[0] before sum[] is overwritten below.
" barrier(CLK_LOCAL_MEM_FENCE);\n"
"#endif\n"
" in_sum=0;\n"
" in_sum_left=0;\n"
" for(int c=0; c<channel4-1; ++c){\n"
" for(int i=lid; i<reduce_size; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+w,b*height+h)));\n"
" in_sum += (in-mean)*(in-mean);\n"
" }\n"
" }\n"
" \n"
" for(int i=lid; i<reduce_size; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)((channel4-1)*width+w,b*height+h)));\n"
" in_sum_left += (in-mean)*(in-mean);\n"
" }\n"
" \n"
" in_sum.x=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" for(int i=1; i<channel_remain; ++i){\n"
" in_sum_left_ptr[0] += in_sum_left_ptr[i];\n"
" }\n"
" \n"
" sum[lid]=in_sum.x+in_sum_left.x;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" float4 square_sum=sum[0]/(float4)sum_size;\n"
" float4 value=(float4)1.0f/(float4)sqrt(square_sum+(float4)epsilon);\n"
" for(int c=0; c<channel4; ++c){\n"
" for(int i=lid; i<reduce_size; i+=LOCAL_SIZE){\n"
" int w=i % width;\n"
" int h=i/width;\n"
" float4 in=convert_float4(RI_F(input,SAMPLER,(int2)(c*width+w,b*height+h)));\n"
"#ifdef GAMMA_BETA\n"
// NOTE(review): c is a 4-channel block index here, and the scalar gamma/beta value is broadcast
// to all four lanes; verify the host lays gamma/beta out per block, otherwise channels 4c..4c+3
// share one affine parameter.
" float4 out=(in-mean)*value*(float4)gamma[c*reduce_size+i]+(float4)beta[c*reduce_size+i];\n"
"#else\n"
" float4 out=(in-mean)*value;\n"
"#endif\n"
" WI_F(output,(int2)(c*width+w,b*height+h),CONVERT_FLOAT4(out));\n"
" }\n"
" }\n"
" }\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* gemm_conv1x1_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"#define GLOBAL_SIZE_DIM3 "" __private int global_size_dim0,__private int global_size_dim1,__private int global_size_dim2,\n"
"#define UNIFORM_BOUNDRY_CHECK3(index0, index1, index2) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1 || index2 >= global_size_dim2) { "" return; "" }\n"
"#define UCHAR16_TO_2CHAR16(a, b, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; "" b.s0 = (c.s8 >> 4) - 8; b.s1 = (c.s8 & 15) - 8; b.s2 = (c.s9 >> 4) - 8; b.s3 = (c.s9 & 15) - 8; b.s4 = (c.sa >> 4) - 8; b.s5 = (c.sa & 15) - 8; b.s6 = (c.sb >> 4) - 8; b.s7 = (c.sb & 15) - 8; "" b.s8=(c.sc >> 4)-8; b.s9=(c.sc & 15)-8; b.sa=(c.sd >> 4)-8; b.sb=(c.sd & 15)-8; b.sc=(c.se >> 4)-8; b.sd=(c.se & 15)-8; b.se=(c.sf >> 4)-8; b.sf=(c.sf & 15)-8;\n"
"#define UCHAR8_TO_CHAR16(a, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8=(c.s4 >> 4)-8; a.s9=(c.s4 & 15)-8; a.sa=(c.s5 >> 4)-8; a.sb=(c.s5 & 15)-8; a.sc=(c.s6 >> 4)-8; a.sd=(c.s6 & 15)-8; a.se=(c.s7 >> 4)-8; a.sf=(c.s7 & 15)-8;\n"
"#define DOT16X16(a, b, c) "" c += dot(a.s0123, b.s0123); "" c += dot(a.s4567, b.s4567); "" c += dot(a.s89ab, b.s89ab); "" c += dot(a.scdef,b.scdef);\n"
"#if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
"#define CHANNEL_PACK 32\n"
"#else\n"
"#define CHANNEL_PACK 16\n"
"#endif\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#define WEIGHT_STRIDE 16\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
"#define WEIGHT_STRIDE 8\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#ifdef USE_IMAGE\n"
"inline COMPUTE_FLOAT16 readWeight(__read_only image2d_t weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n"
" return CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(ix,iy))))*scale+offset;\n"
"}\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"inline COMPUTE_FLOAT16 readWeight(__global const char *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n"
" return CONVERT_COMPUTE_FLOAT16(vload16(0,weight))*scale+offset;\n"
"}\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
"inline COMPUTE_FLOAT16 readWeight(__global const uchar *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n"
" uchar16 charWeightsInt40=vload16(0,weight);\n"
" uchar8 charWeightsInt4=vload8(0,weight);\n"
" char16 charWeights=0;\n"
" UCHAR8_TO_CHAR16(charWeights,charWeightsInt4);\n"
" return CONVERT_COMPUTE_FLOAT16(charWeights)*scale+offset;\n"
"}\n"
"#endif\n"
"#endif\n"
"__kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2\n"
" #ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
" #else\n"
" #if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
" #elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
" #endif\n"
" #endif\n"
" __global const float *dequantScaleOffset,\n"
" __global FLOAT* output,\n"
" __private const int outputChannelAlign,\n"
" __private const int outputChannel4Align,\n"
" __private const int blockDim){\n"
" const int x=get_global_id(0); //ic\n"
" const int y=get_global_id(1); //oc\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" \n"
" const int ic=x << 5;\n"
" const int oc=y << 2;\n"
" const int output_offset=ic*outputChannelAlign+oc;\n"
" int kindex=(ic/blockDim)*outputChannel4Align*2;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(0,dequantScaleOffset+kindex+oc*2));\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11,weights20,weights21,weights30,weights31;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc,x)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc+1,x)));\n"
" uchar16 charWeightsInt42=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc+2,x)));\n"
" uchar16 charWeightsInt43=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc+3,x)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt42);\n"
" weights20=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s4+ScaleOffset.s5;\n"
" weights21=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s4+ScaleOffset.s5;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt43);\n"
" weights30=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s6+ScaleOffset.s7;\n"
" weights31=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s6+ScaleOffset.s7;\n"
" }\n"
" COMPUTE_FLOAT *weights00_ptr=(COMPUTE_FLOAT *)&weights00;\n"
" COMPUTE_FLOAT *weights10_ptr=(COMPUTE_FLOAT *)&weights10;\n"
" COMPUTE_FLOAT *weights20_ptr=(COMPUTE_FLOAT *)&weights20;\n"
" COMPUTE_FLOAT *weights30_ptr=(COMPUTE_FLOAT *)&weights30;\n"
" COMPUTE_FLOAT *weights01_ptr=(COMPUTE_FLOAT *)&weights01;\n"
" COMPUTE_FLOAT *weights11_ptr=(COMPUTE_FLOAT *)&weights11;\n"
" COMPUTE_FLOAT *weights21_ptr=(COMPUTE_FLOAT *)&weights21;\n"
" COMPUTE_FLOAT *weights31_ptr=(COMPUTE_FLOAT *)&weights31;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" FLOAT4 out=CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights00_ptr[i],weights10_ptr[i],weights20_ptr[i],weights30_ptr[i]));\n"
" vstore4(out,0,output+output_offset+i*outputChannelAlign);\n"
" }\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" FLOAT4 out=CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights01_ptr[i],weights11_ptr[i],weights21_ptr[i],weights31_ptr[i]));\n"
" vstore4(out,0,output+output_offset+(i+16)*outputChannelAlign);\n"
" }\n"
" #else\n"
" const int ic=x << 4;\n"
" const int oc=y << 2;\n"
"#ifndef USE_IMAGE\n"
" #if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_offset=oc*8;\n"
" int weight_oc_offset=outputChannel4Align*8;\n"
" int weight_stride=8;\n"
" #else\n"
" int weight_offset=oc*16;\n"
" int weight_oc_offset=outputChannel4Align*16;\n"
" int weight_stride=16;\n"
" #endif\n"
"#endif\n"
" const int output_offset=ic*outputChannelAlign+oc;\n"
" int kindex=(ic/blockDim)*outputChannel4Align*2;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(0,dequantScaleOffset+kindex+oc*2));\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 weights0=readWeight(weight,oc,x,ScaleOffset.s0,ScaleOffset.s1);\n"
" COMPUTE_FLOAT16 weights1=readWeight(weight,oc+1,x,ScaleOffset.s2,ScaleOffset.s3);\n"
" COMPUTE_FLOAT16 weights2=readWeight(weight,oc+2,x,ScaleOffset.s4,ScaleOffset.s5);\n"
" COMPUTE_FLOAT16 weights3=readWeight(weight,oc+3,x,ScaleOffset.s6,ScaleOffset.s7);\n"
" #else\n"
" COMPUTE_FLOAT16 weights0=readWeight(weight+weight_offset+x*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" COMPUTE_FLOAT16 weights1=readWeight(weight+weight_offset+x*weight_oc_offset+weight_stride,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" COMPUTE_FLOAT16 weights2=readWeight(weight+weight_offset+x*weight_oc_offset+2*weight_stride,0,0,ScaleOffset.s4,ScaleOffset.s5);\n"
" COMPUTE_FLOAT16 weights3=readWeight(weight+weight_offset+x*weight_oc_offset+3*weight_stride,0,0,ScaleOffset.s6,ScaleOffset.s7);\n"
" #endif\n"
" COMPUTE_FLOAT *weights0_ptr=(COMPUTE_FLOAT*)&weights0;\n"
" COMPUTE_FLOAT *weights1_ptr=(COMPUTE_FLOAT*)&weights1;\n"
" COMPUTE_FLOAT *weights2_ptr=(COMPUTE_FLOAT*)&weights2;\n"
" COMPUTE_FLOAT *weights3_ptr=(COMPUTE_FLOAT*)&weights3;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" FLOAT4 out=CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights0_ptr[i],weights1_ptr[i],weights2_ptr[i],weights3_ptr[i]));\n"
" vstore4(out,0,output+output_offset+i*outputChannelAlign);\n"
" }\n"
" #endif\n"
"}\n"
"__kernel void reshape_nchw4_nhwc4(GLOBAL_SIZE_DIM2\n"
"__global const FLOAT* input,\n"
"__global FLOAT* output,\n"
"__private const int bhw,\n"
"__private const int channel,\n"
"__private const int channelAlign){\n"
" const int x=get_global_id(0); //c\n"
" const int y=get_global_id(1); //bhw\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" \n"
" const int x4=x << 2;\n"
" const int y4=y << 2;\n"
" const int input_offset=(x*bhw+y4)*4;\n"
" FLOAT4 in0=vload4(0,input+input_offset);\n"
" FLOAT4 in1=(y4+1<bhw) ? vload4(0,input+input_offset+4) : (FLOAT4)0;\n"
" FLOAT4 in2=(y4+2<bhw) ? vload4(0,input+input_offset+8) : (FLOAT4)0;\n"
" FLOAT4 in3=(y4+3<bhw) ? vload4(0,input+input_offset+12) : (FLOAT4)0;\n"
" \n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" if(x4+3 >= channel){\n"
" FLOAT *in0_ptr=(FLOAT*)&in0;\n"
" FLOAT *in1_ptr=(FLOAT*)&in1;\n"
" FLOAT *in2_ptr=(FLOAT*)&in2;\n"
" FLOAT *in3_ptr=(FLOAT*)&in3;\n"
" int remain=x4+3-channel;\n"
" for(int i=remain; i >= 0; i--){\n"
" in0_ptr[3-i]=0;\n"
" in1_ptr[3-i]=0;\n"
" in2_ptr[3-i]=0;\n"
" in3_ptr[3-i]=0;\n"
" }\n"
" }\n"
"#endif\n"
" \n"
"#ifdef FORMAT_CNHW\n"
" int idx=x/4;\n"
" int idy=x % 4;\n"
" const int bhw4=(bhw+3)/4*4;\n"
" int output_offset=((idx*bhw4+y4)*4+idy)*4; // [c/16 b 4 4]\n"
" vstore4(in0,0,output+output_offset);\n"
" vstore4(in1,0,output+output_offset+16);\n"
" vstore4(in2,0,output+output_offset+32);\n"
" vstore4(in3,0,output+output_offset+48);\n"
"#else\n"
" FLOAT16 out=(FLOAT16)(in0.s0,in1.s0,in2.s0,in3.s0,in0.s1,in1.s1,in2.s1,in3.s1,in0.s2,in1.s2,in2.s2,in3.s2,in0.s3,in1.s3,in2.s3,in3.s3);\n"
" const int output_offset=(y*channelAlign+x4)*4;\n"
" vstore16(out,0,output+output_offset);\n"
"#endif\n"
"}\n"
"__kernel void reshape_nhwc4_nchw4(GLOBAL_SIZE_DIM2\n"
"__global const FLOAT* input,\n"
"__global FLOAT* output,\n"
"__private const int bhw,\n"
"__private const int channelAlign){\n"
" const int x=get_global_id(0); //c\n"
" const int y=get_global_id(1); //bhw\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" \n"
" const int x4=x << 2;\n"
" const int y4=y << 2;\n"
" const int output_offset=(x*bhw+y4)*4;\n"
" \n"
" const int input_offset=(y*channelAlign+x4)*4;\n"
" FLOAT16 in=vload16(0,input+input_offset);\n"
" \n"
" FLOAT4 out0=(FLOAT4)(in.s0,in.s4,in.s8,in.sc);\n"
" FLOAT4 out1=(FLOAT4)(in.s1,in.s5,in.s9,in.sd);\n"
" FLOAT4 out2=(FLOAT4)(in.s2,in.s6,in.sa,in.se);\n"
" FLOAT4 out3=(FLOAT4)(in.s3,in.s7,in.sb,in.sf);\n"
" vstore4(out0,0,output+output_offset);\n"
" if(y4+1 >= bhw) return;\n"
" vstore4(out1,0,output+output_offset+4);\n"
" if(y4+2 >= bhw) return;\n"
" vstore4(out2,0,output+output_offset+8);\n"
" if(y4+3 >= bhw) return;\n"
" vstore4(out3,0,output+output_offset+12);\n"
"}\n"
"__kernel void gemm_b4_c4_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
"#endif\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int bhw4,\n"
" __private const int dstChannelAlign,\n"
" __private const int srcChannelAlign,\n"
" __private const int blockNum,\n"
" __private const int blockDim) {\n"
" const int x=get_global_id(0); //c\n"
" const int y=get_global_id(1); //b\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" const int out_c_idx=x << 2;\n"
" const int out_b_idx=y << 2;\n"
" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(0,bias+out_c_idx));\n"
" COMPUTE_FLOAT4 out=(COMPUTE_FLOAT4)bias0.s0;\n"
" COMPUTE_FLOAT4 out1=(COMPUTE_FLOAT4)bias0.s1,out2=(COMPUTE_FLOAT4)bias0.s2,out3=(COMPUTE_FLOAT4)bias0.s3;\n"
"#ifdef FORMAT_CNHW\n"
" int input_offset=out_b_idx*16;\n"
"#else\n"
" int input_offset=out_b_idx*srcChannelAlign;\n"
"#endif\n"
" int out_offset=out_b_idx*dstChannelAlign+out_c_idx*4;\n"
" \n"
"#ifndef USE_IMAGE\n"
" int weight_offset=out_c_idx*WEIGHT_STRIDE;\n"
" int weight_oc_offset=dstChannelAlign*WEIGHT_STRIDE;\n"
"#endif\n"
" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n"
" \n"
" for (int i=0; i<blockNum; i++){\n"
" int kindex=i*dstChannelAlign*2;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(0,dequantScaleOffset+kindex+out_c_idx*2));\n"
" for (int j=0; j<loop; j++) {\n"
" int k=i*loop+j;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11,weights20,weights21,weights30,weights31;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx,k)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx+1,k)));\n"
" uchar16 charWeightsInt42=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx+2,k)));\n"
" uchar16 charWeightsInt43=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx+3,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt42);\n"
" weights20=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s4+ScaleOffset.s5;\n"
" weights21=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s4+ScaleOffset.s5;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt43);\n"
" weights30=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s6+ScaleOffset.s7;\n"
" weights31=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s6+ScaleOffset.s7;\n"
" }\n"
" #ifdef FORMAT_CNHW\n"
" int k2=k << 1;\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16));\n"
" DOT16X16(in,weights00,out.s0);\n"
" DOT16X16(in,weights10,out1.s0);\n"
" DOT16X16(in,weights20,out2.s0);\n"
" DOT16X16(in,weights30,out3.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+16));\n"
" DOT16X16(in,weights00,out.s1);\n"
" DOT16X16(in,weights10,out1.s1);\n"
" DOT16X16(in,weights20,out2.s1);\n"
" DOT16X16(in,weights30,out3.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+32));\n"
" DOT16X16(in,weights00,out.s2);\n"
" DOT16X16(in,weights10,out1.s2);\n"
" DOT16X16(in,weights20,out2.s2);\n"
" DOT16X16(in,weights30,out3.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+48));\n"
" DOT16X16(in,weights00,out.s3);\n"
" DOT16X16(in,weights10,out1.s3);\n"
" DOT16X16(in,weights20,out2.s3);\n"
" DOT16X16(in,weights30,out3.s3);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16));\n"
" DOT16X16(in,weights01,out.s0);\n"
" DOT16X16(in,weights11,out1.s0);\n"
" DOT16X16(in,weights21,out2.s0);\n"
" DOT16X16(in,weights31,out3.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+16));\n"
" DOT16X16(in,weights01,out.s1);\n"
" DOT16X16(in,weights11,out1.s1);\n"
" DOT16X16(in,weights21,out2.s1);\n"
" DOT16X16(in,weights31,out3.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+32));\n"
" DOT16X16(in,weights01,out.s2);\n"
" DOT16X16(in,weights11,out1.s2);\n"
" DOT16X16(in,weights21,out2.s2);\n"
" DOT16X16(in,weights31,out3.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+48));\n"
" DOT16X16(in,weights01,out.s3);\n"
" DOT16X16(in,weights11,out1.s3);\n"
" DOT16X16(in,weights21,out2.s3);\n"
" DOT16X16(in,weights31,out3.s3);\n"
" #else\n"
" int k32=k << 5;\n"
" COMPUTE_FLOAT *weights00_ptr=(COMPUTE_FLOAT *)&weights00;\n"
" COMPUTE_FLOAT *weights10_ptr=(COMPUTE_FLOAT *)&weights10;\n"
" COMPUTE_FLOAT *weights20_ptr=(COMPUTE_FLOAT *)&weights20;\n"
" COMPUTE_FLOAT *weights30_ptr=(COMPUTE_FLOAT *)&weights30;\n"
" COMPUTE_FLOAT *weights01_ptr=(COMPUTE_FLOAT *)&weights01;\n"
" COMPUTE_FLOAT *weights11_ptr=(COMPUTE_FLOAT *)&weights11;\n"
" COMPUTE_FLOAT *weights21_ptr=(COMPUTE_FLOAT *)&weights21;\n"
" COMPUTE_FLOAT *weights31_ptr=(COMPUTE_FLOAT *)&weights31;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k32+i)*4));\n"
" out=mad(in,weights00_ptr[i],out);\n"
" out1=mad(in,weights10_ptr[i],out1);\n"
" out2=mad(in,weights20_ptr[i],out2);\n"
" out3=mad(in,weights30_ptr[i],out3);\n"
" }\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k32+i+16)*4));\n"
" out=mad(in,weights01_ptr[i],out);\n"
" out1=mad(in,weights11_ptr[i],out1);\n"
" out2=mad(in,weights21_ptr[i],out2);\n"
" out3=mad(in,weights31_ptr[i],out3);\n"
" }\n"
" #endif\n"
" #else\n"
" COMPUTE_FLOAT16 weights0,weights1,weights2,weights3;\n"
" #ifdef USE_IMAGE\n"
" weights0=readWeight(weight,out_c_idx,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight,out_c_idx+1,k,ScaleOffset.s2,ScaleOffset.s3);\n"
" weights2=readWeight(weight,out_c_idx+2,k,ScaleOffset.s4,ScaleOffset.s5);\n"
" weights3=readWeight(weight,out_c_idx+3,k,ScaleOffset.s6,ScaleOffset.s7);\n"
" #else\n"
" weights0=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight+weight_offset+k*weight_oc_offset+WEIGHT_STRIDE,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" weights2=readWeight(weight+weight_offset+k*weight_oc_offset+2*WEIGHT_STRIDE,0,0,ScaleOffset.s4,ScaleOffset.s5);\n"
" weights3=readWeight(weight+weight_offset+k*weight_oc_offset+3*WEIGHT_STRIDE,0,0,ScaleOffset.s6,ScaleOffset.s7);\n"
" #endif\n"
" #ifdef FORMAT_CNHW\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16));\n"
" DOT16X16(in,weights0,out.s0);\n"
" DOT16X16(in,weights1,out1.s0);\n"
" DOT16X16(in,weights2,out2.s0);\n"
" DOT16X16(in,weights3,out3.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+16));\n"
" DOT16X16(in,weights0,out.s1);\n"
" DOT16X16(in,weights1,out1.s1);\n"
" DOT16X16(in,weights2,out2.s1);\n"
" DOT16X16(in,weights3,out3.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+32));\n"
" DOT16X16(in,weights0,out.s2);\n"
" DOT16X16(in,weights1,out1.s2);\n"
" DOT16X16(in,weights2,out2.s2);\n"
" DOT16X16(in,weights3,out3.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+48));\n"
" DOT16X16(in,weights0,out.s3);\n"
" DOT16X16(in,weights1,out1.s3);\n"
" DOT16X16(in,weights2,out2.s3);\n"
" DOT16X16(in,weights3,out3.s3);\n"
" #else\n"
" int k16=k << 4;\n"
" COMPUTE_FLOAT *weights0_ptr=(COMPUTE_FLOAT *)&weights0;\n"
" COMPUTE_FLOAT *weights1_ptr=(COMPUTE_FLOAT *)&weights1;\n"
" COMPUTE_FLOAT *weights2_ptr=(COMPUTE_FLOAT *)&weights2;\n"
" COMPUTE_FLOAT *weights3_ptr=(COMPUTE_FLOAT *)&weights3;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k16+i)*4));\n"
" out=mad(in,weights0_ptr[i],out);\n"
" out1=mad(in,weights1_ptr[i],out1);\n"
" out2=mad(in,weights2_ptr[i],out2);\n"
" out3=mad(in,weights3_ptr[i],out3);\n"
" }\n"
" #endif\n"
" #endif\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out=fmax(out,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
" out2=fmax(out2,(COMPUTE_FLOAT4)0);\n"
" out3=fmax(out3,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),0,output+out_offset+4);\n"
" vstore4(CONVERT_FLOAT4(out2),0,output+out_offset+8);\n"
" vstore4(CONVERT_FLOAT4(out3),0,output+out_offset+12);\n"
"}\n"
"__kernel void gemm_b4_c2_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
"#endif\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int bhw4,\n"
" __private const int dstChannelAlign,\n"
" __private const int srcChannelAlign,\n"
" __private const int blockNum,\n"
" __private const int blockDim) {\n"
" const int x=get_global_id(0); //c\n"
" const int y=get_global_id(1); //b\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" const int out_c_idx=x << 1;\n"
" const int out_b_idx=y << 2;\n"
" COMPUTE_FLOAT2 bias0=CONVERT_COMPUTE_FLOAT2(vload2(0,bias+out_c_idx));\n"
" COMPUTE_FLOAT4 out=(COMPUTE_FLOAT4)bias0.s0;\n"
" COMPUTE_FLOAT4 out1=(COMPUTE_FLOAT4)bias0.s1;\n"
" \n"
"#ifdef FORMAT_CNHW\n"
" int input_offset=out_b_idx*16;\n"
"#else\n"
" int input_offset=out_b_idx*srcChannelAlign;\n"
"#endif\n"
" int out_offset=out_b_idx*dstChannelAlign+out_c_idx*4;\n"
" \n"
"#ifndef USE_IMAGE\n"
" int weight_offset=out_c_idx*WEIGHT_STRIDE;\n"
" int weight_oc_offset=dstChannelAlign*WEIGHT_STRIDE;\n"
"#endif\n"
" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n"
" for (int i=0; i<blockNum; i++){\n"
" int kindex=i*dstChannelAlign*2;\n"
" COMPUTE_FLOAT4 ScaleOffset=CONVERT_COMPUTE_FLOAT4(vload4(0,dequantScaleOffset+kindex+out_c_idx*2));\n"
" for (int j=0; j<loop; j++) {\n"
" int k=i*loop+j;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx,k)));\n"
" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx+1,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n"
" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n"
" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n"
" }\n"
" #ifdef FORMAT_CNHW\n"
" int k2=k << 1;\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16));\n"
" DOT16X16(in,weights00,out.s0);\n"
" DOT16X16(in,weights10,out1.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+16));\n"
" DOT16X16(in,weights00,out.s1);\n"
" DOT16X16(in,weights10,out1.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+32));\n"
" DOT16X16(in,weights00,out.s2);\n"
" DOT16X16(in,weights10,out1.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+48));\n"
" DOT16X16(in,weights00,out.s3);\n"
" DOT16X16(in,weights10,out1.s3);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16));\n"
" DOT16X16(in,weights01,out.s0);\n"
" DOT16X16(in,weights11,out1.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+16));\n"
" DOT16X16(in,weights01,out.s1);\n"
" DOT16X16(in,weights11,out1.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+32));\n"
" DOT16X16(in,weights01,out.s2);\n"
" DOT16X16(in,weights11,out1.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+48));\n"
" DOT16X16(in,weights01,out.s3);\n"
" DOT16X16(in,weights11,out1.s3);\n"
" #else\n"
" int k32=k << 5;\n"
" COMPUTE_FLOAT *weights00_ptr=(COMPUTE_FLOAT *)&weights00;\n"
" COMPUTE_FLOAT *weights10_ptr=(COMPUTE_FLOAT *)&weights10;\n"
" COMPUTE_FLOAT *weights01_ptr=(COMPUTE_FLOAT *)&weights01;\n"
" COMPUTE_FLOAT *weights11_ptr=(COMPUTE_FLOAT *)&weights11;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k32+i)*4));\n"
" out=mad(in,weights00_ptr[i],out);\n"
" out1=mad(in,weights10_ptr[i],out1);\n"
" }\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k32+i+16)*4));\n"
" out=mad(in,weights01_ptr[i],out);\n"
" out1=mad(in,weights11_ptr[i],out1);\n"
" }\n"
" #endif\n"
" #else\n"
" COMPUTE_FLOAT16 weights0,weights1;\n"
" #ifdef USE_IMAGE\n"
" weights0=readWeight(weight,out_c_idx,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight,out_c_idx+1,k,ScaleOffset.s2,ScaleOffset.s3);\n"
" #else\n"
" weights0=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" weights1=readWeight(weight+weight_offset+k*weight_oc_offset+WEIGHT_STRIDE,0,0,ScaleOffset.s2,ScaleOffset.s3);\n"
" #endif\n"
" #ifdef FORMAT_CNHW\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16));\n"
" DOT16X16(in,weights0,out.s0);\n"
" DOT16X16(in,weights1,out1.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+16));\n"
" DOT16X16(in,weights0,out.s1);\n"
" DOT16X16(in,weights1,out1.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+32));\n"
" DOT16X16(in,weights0,out.s2);\n"
" DOT16X16(in,weights1,out1.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+48));\n"
" DOT16X16(in,weights0,out.s3);\n"
" DOT16X16(in,weights1,out1.s3);\n"
" #else\n"
" int k16=k << 4;\n"
" COMPUTE_FLOAT *weights0_ptr=(COMPUTE_FLOAT *)&weights0;\n"
" COMPUTE_FLOAT *weights1_ptr=(COMPUTE_FLOAT *)&weights1;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k16+i)*4));\n"
" out=mad(in,weights0_ptr[i],out);\n"
" out1=mad(in,weights1_ptr[i],out1);\n"
" }\n"
" #endif\n"
" #endif\n"
" }\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out=fmax(out,(COMPUTE_FLOAT4)0);\n"
" out1=fmax(out1,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
" out1=clamp(out1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n"
" vstore4(CONVERT_FLOAT4(out1),0,output+out_offset+4);\n"
"}\n"
"__kernel void gemm_b4_c1_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
"#endif\n"
"#endif\n"
" __global const float *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int bhw4,\n"
" __private const int dstChannelAlign,\n"
" __private const int srcChannelAlign,\n"
" __private const int blockNum,\n"
" __private const int blockDim) {\n"
" const int x=get_global_id(0); //c\n"
" const int y=get_global_id(1); //b\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" const int out_c_idx=x;\n"
" const int out_b_idx=y << 2;\n"
" COMPUTE_FLOAT bias0=bias[out_c_idx];\n"
" COMPUTE_FLOAT4 out=(COMPUTE_FLOAT4)bias0;\n"
" \n"
"#ifdef FORMAT_CNHW\n"
" int input_offset=out_b_idx*16;\n"
"#else\n"
" int input_offset=out_b_idx*srcChannelAlign;\n"
"#endif\n"
" int out_offset=out_b_idx*dstChannelAlign+out_c_idx*4;\n"
"#ifndef USE_IMAGE\n"
" int weight_offset=out_c_idx*WEIGHT_STRIDE;\n"
" int weight_oc_offset=dstChannelAlign*WEIGHT_STRIDE;\n"
"#endif\n"
" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n"
" \n"
" for (int i=0; i<blockNum; i++){\n"
" int kindex=i*dstChannelAlign*2;\n"
" COMPUTE_FLOAT2 ScaleOffset=CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx,dequantScaleOffset+kindex));\n"
" for (int j=0; j<loop; j++) {\n"
" int k=i*loop+j;\n"
" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n"
" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11;\n"
" {\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(out_c_idx,k)));\n"
" char16 charWeights0,charWeights1;\n"
" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n"
" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n"
" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n"
" }\n"
" #ifdef FORMAT_CNHW\n"
" int k2=k << 1;\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16));\n"
" DOT16X16(in,weights00,out.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+16));\n"
" DOT16X16(in,weights00,out.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+32));\n"
" DOT16X16(in,weights00,out.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k2*bhw4*16+48));\n"
" DOT16X16(in,weights00,out.s3);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16));\n"
" DOT16X16(in,weights01,out.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+16));\n"
" DOT16X16(in,weights01,out.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+32));\n"
" DOT16X16(in,weights01,out.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+(k2+1)*bhw4*16+48));\n"
" DOT16X16(in,weights01,out.s3);\n"
" #else\n"
" int k32=k << 5;\n"
" COMPUTE_FLOAT *weights00_ptr=(COMPUTE_FLOAT *)&weights00;\n"
" COMPUTE_FLOAT *weights01_ptr=(COMPUTE_FLOAT *)&weights01;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k32+i)*4));\n"
" out=mad(in,weights00_ptr[i],out);\n"
" }\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k32+i+16)*4));\n"
" out=mad(in,weights01_ptr[i],out);\n"
" }\n"
" #endif\n"
" #else\n"
" COMPUTE_FLOAT16 weights;\n"
" #ifdef USE_IMAGE\n"
" weights=readWeight(weight,out_c_idx,k,ScaleOffset.s0,ScaleOffset.s1);\n"
" #else\n"
" weights=readWeight(weight+weight_offset+k*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n"
" #endif\n"
" #ifdef FORMAT_CNHW\n"
" COMPUTE_FLOAT16 in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16));\n"
" DOT16X16(in,weights,out.s0);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+16));\n"
" DOT16X16(in,weights,out.s1);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+32));\n"
" DOT16X16(in,weights,out.s2);\n"
" in=CONVERT_COMPUTE_FLOAT16(vload16(0,input+input_offset+k*bhw4*16+48));\n"
" DOT16X16(in,weights,out.s3);\n"
" #else\n"
" int k16=k << 4;\n"
" COMPUTE_FLOAT *weights_ptr=(COMPUTE_FLOAT *)&weights;\n"
" #pragma unroll\n"
" for (int i=0; i<16; ++i){\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+(k16+i)*4));\n"
" out=mad(in,weights_ptr[i],out);\n"
" }\n"
" #endif\n"
" #endif\n"
" }\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out=fmax(out,(COMPUTE_FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n"
"#endif\n"
" vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n"
"}\n"
;
#endif
// OpenCL C source text for the `winogradTransformDest` kernel: the Winograd
// output transform for a 6x6 transformed tile (36 samples) producing up to a
// 2x2 block of output pixels. The suffix "2_5_1" presumably means output unit 2,
// kernel size 5, variant 1.
// TODO confirm against the generator.
//
// Everything between the quotes is program text. It is presumably handed to the
// OpenCL compiler at runtime, so changing any byte changes the kernel. Edit the
// source this table is generated from rather than this literal (NOTE(review):
// generator location not visible here).
//
// Kernel contract, as written in the text below:
//  - Global id 0 (pos.x) enumerates unitWidth*unitHeight tiles.
//  - Global id 1 (pos.y) is the output channel block (< dstChannelC4).
//  - Tile sample i is read from uInput at
//    x = unitWidth_idx + unitWidth*i and y = pos.y*unitHeight + unitHeight_idx.
//    Sample S<a><b> uses i = a + 6*b, so S01 is i = 6.
//  - First pass: for each a, combine over b with the rows
//    [1,1,1,1,1,0] and [0,1,-1,2,-2,1] (m<a>0, m<a>1).
//  - Second pass: combine over a with the same rows.
//  - Each output adds the bias read from uBias at (pos.y, 0). RELU or RELU6 is
//    applied when the corresponding macro is defined.
//  - Each output pixel is stored only if ox < dstWidth and oy < dstHeight, at
//    image coordinates (ox + pos.y*dstWidth, oy + batchOffset*dstHeight).
//  - FLOAT4, FLOAT, RI_F and WI_F are not defined in this text. They are
//    presumably supplied as build options.
const char* winogradTransformDest2_5_1 = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void winogradTransformDest(__read_only image2d_t uInput,// 0\n"
" __read_only image2d_t uBias,__write_only image2d_t uOutput,\n"
" __private const int unitWidth,// 3\n"
" __private const int unitHeight,__private const int dstWidth,\n"
" __private const int dstHeight,// 6\n"
" __private const int dstChannelC4,__private const int batchOffset) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" if (pos.x<unitWidth*unitHeight && pos.y<dstChannelC4) {\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int srcY=pos.y*unitHeight+unitHeight_idx;\n"
" FLOAT4 bias=RI_F(uBias,SAMPLER,(int2)(pos.y,0));\n"
" \n"
" {\n"
" int oyStart=unitHeight_idx*2;\n"
" int oxStart=unitWidth_idx*2;\n"
// Load the 36 samples of this tile; S<a><b> sits at column offset unitWidth*(a + 6*b).
" FLOAT4 S00=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*0,srcY));\n"
" FLOAT4 S10=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*1,srcY));\n"
" FLOAT4 S20=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*2,srcY));\n"
" FLOAT4 S30=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*3,srcY));\n"
" FLOAT4 S40=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*4,srcY));\n"
" FLOAT4 S50=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*5,srcY));\n"
" FLOAT4 S01=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*6,srcY));\n"
" FLOAT4 S11=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*7,srcY));\n"
" FLOAT4 S21=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*8,srcY));\n"
" FLOAT4 S31=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*9,srcY));\n"
" FLOAT4 S41=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*10,srcY));\n"
" FLOAT4 S51=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*11,srcY));\n"
" FLOAT4 S02=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*12,srcY));\n"
" FLOAT4 S12=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*13,srcY));\n"
" FLOAT4 S22=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*14,srcY));\n"
" FLOAT4 S32=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*15,srcY));\n"
" FLOAT4 S42=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*16,srcY));\n"
" FLOAT4 S52=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*17,srcY));\n"
" FLOAT4 S03=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*18,srcY));\n"
" FLOAT4 S13=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*19,srcY));\n"
" FLOAT4 S23=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*20,srcY));\n"
" FLOAT4 S33=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*21,srcY));\n"
" FLOAT4 S43=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*22,srcY));\n"
" FLOAT4 S53=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*23,srcY));\n"
" FLOAT4 S04=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*24,srcY));\n"
" FLOAT4 S14=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*25,srcY));\n"
" FLOAT4 S24=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*26,srcY));\n"
" FLOAT4 S34=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*27,srcY));\n"
" FLOAT4 S44=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*28,srcY));\n"
" FLOAT4 S54=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*29,srcY));\n"
" FLOAT4 S05=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*30,srcY));\n"
" FLOAT4 S15=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*31,srcY));\n"
" FLOAT4 S25=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*32,srcY));\n"
" FLOAT4 S35=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*33,srcY));\n"
" FLOAT4 S45=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*34,srcY));\n"
" FLOAT4 S55=RI_F(uInput,SAMPLER,(int2)(unitWidth_idx+unitWidth*35,srcY));\n"
// First pass: combine over b with rows [1,1,1,1,1,0] (m<a>0) and [0,1,-1,2,-2,1] (m<a>1).
" FLOAT4 m00=+S00+S01+S02+S03+S04;\n"
" FLOAT4 m10=+S10+S11+S12+S13+S14;\n"
" FLOAT4 m20=+S20+S21+S22+S23+S24;\n"
" FLOAT4 m30=+S30+S31+S32+S33+S34;\n"
" FLOAT4 m40=+S40+S41+S42+S43+S44;\n"
" FLOAT4 m50=+S50+S51+S52+S53+S54;\n"
" FLOAT4 m01=+S01-S02+(FLOAT)2.0*S03-(FLOAT)2.0*S04+S05;\n"
" FLOAT4 m11=+S11-S12+(FLOAT)2.0*S13-(FLOAT)2.0*S14+S15;\n"
" FLOAT4 m21=+S21-S22+(FLOAT)2.0*S23-(FLOAT)2.0*S24+S25;\n"
" FLOAT4 m31=+S31-S32+(FLOAT)2.0*S33-(FLOAT)2.0*S34+S35;\n"
" FLOAT4 m41=+S41-S42+(FLOAT)2.0*S43-(FLOAT)2.0*S44+S45;\n"
" FLOAT4 m51=+S51-S52+(FLOAT)2.0*S53-(FLOAT)2.0*S54+S55;\n"
// Second pass, one scope per output pixel: combine over a, add bias,
// apply the optional activation, then do a bounds-checked store.
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m00+m10+m20+m30+m40;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m10-m20+(FLOAT)2.0*m30-(FLOAT)2.0*m40+m50;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m01+m11+m21+m31+m41;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
// Output (1,1) scales by (FLOAT4)2.0 rather than the (FLOAT)2.0 used above.
// The two are numerically equivalent; the difference is cosmetic only.
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" int imageOx=ox+pos.y*dstWidth;\n"
" int imageOy=oy+batchOffset*dstHeight;\n"
" FLOAT4 res=bias+m11-m21+(FLOAT4)2.0*m31-(FLOAT4)2.0*m41+m51;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n"
" }\n"
" }\n"
" }\n"
" }\n"
"}\n"
;
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source text for the `cast_buf` kernel, an element-wise type cast on
// buffers. It is only present when MNN_OPENCL_BUFFER_CLOSED is not defined.
// Everything between the quotes is program text, so changing any byte changes
// the kernel.
//
// Macros used but not defined in this text: INPUT_TYPE, OUTPUT_TYPE,
// CONVERT_OUTPUT4, and the optional switches PACK_LEAVE and TO_BOOL. They are
// presumably supplied as build options (TODO confirm).
//
// Behavior, as written:
//  - Each work-item handles 4 consecutive elements starting at get_global_id(0)*4.
//    Global id 1 only takes part in the bounds check.
//  - With PACK_LEAVE, a work-item whose 4-wide window would reach `size`
//    (inp_offset+3 >= size) falls back to a scalar loop over the remaining
//    elements.
//  - Without PACK_LEAVE, the vload4/vstore4 path is taken unconditionally.
//    Presumably the host only omits PACK_LEAVE when size is a multiple of 4.
//    TODO confirm.
//  - TO_BOOL maps each element to 0 or 1 after converting it to int. Both the
//    scalar (int) cast and the default convert_int4 truncate toward zero, so a
//    float input such as 0.5 becomes 0.
//    NOTE(review): confirm this is intended for fractional inputs.
const char* cast_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__kernel void cast_buf(GLOBAL_SIZE_2_DIMS\n"
" __global INPUT_TYPE* input,\n"
" __global OUTPUT_TYPE* output,\n"
" __private const int size\n"
" ) {\n"
" const int idx=get_global_id(0);\n"
" const int idy=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(idx,idy);\n"
" const int inp_offset=idx*4;\n"
// Scalar tail path: only compiled when the host defines PACK_LEAVE.
"#ifdef PACK_LEAVE\n"
" if(inp_offset+3 >= size){\n"
" int remain=size-inp_offset;\n"
" for(int i=0; i<remain; ++i){\n"
" #ifdef TO_BOOL\n"
" int value=(int)input[inp_offset+i];\n"
" value=value == 0 ? 0 : 1;\n"
" output[inp_offset+i]=(OUTPUT_TYPE)value;\n"
" #else\n"
" output[inp_offset+i]=(OUTPUT_TYPE)input[inp_offset+i];\n"
" #endif\n"
" }\n"
" }else {\n"
"#endif\n"
// Vector path: 4 elements per work-item.
" #ifdef TO_BOOL\n"
" int4 value=convert_int4(vload4(0,input+inp_offset));\n"
" value=value == (int4)0 ? (int4)0 : (int4)1;\n"
" vstore4(CONVERT_OUTPUT4(value),0,output+inp_offset);\n"
" #else\n"
" vstore4(CONVERT_OUTPUT4(vload4(0,input+inp_offset)),0,output+inp_offset);\n"
" #endif\n"
"#ifdef PACK_LEAVE\n"
" }\n"
"#endif\n"
"}\n"
;
#endif
// OpenCL C source text for the image-based reduction kernels reduct_width,
// reduct_height, reduct_channel and reduct_batch. Everything between the quotes
// is program text, so changing any byte changes the kernels. Make kernel
// changes in the source this table is generated from (NOTE(review): presumably
// a .cl file; confirm).
//
// Macros used but not defined in this text (presumably host build options; TODO confirm):
//   INPUT_TYPE_I, INPUT_TYPE_I4, OUTPUT_TYPE_I4   scalar, vector and output types
//   VALUE                                         initial accumulator value (assumed to be the identity of OPERATE)
//   OPERATE(a, b)                                 combine step applied while folding
//   RI_DATA, WI_DATA, CONVERT_OUTPUT_I4            image read/write and output conversion
//   LOCAL_SIZE                                    > 0 selects the work-group path that uses __local memory
//   GET_AVG                                       divides the folded result by the reduced extent
//
// Input addressing shared by all kernels:
//   x = channel_block*inputWidth + w,  y = batch*inputHeight + h.
//   Each texel holds 4 channels (see `remain` in reduct_channel).
//
// Caveats of the LOCAL_SIZE > 0 path, as written:
//  - The tree loop starts at i = LOCAL_SIZE/2 and halves i each step. It
//    therefore combines every partial only when LOCAL_SIZE is a power of two.
//  - DEAL_NON_UNIFORM_DIM3 can return before barrier(). This is only valid if
//    the early return is uniform across each work-group. Presumably the host
//    sizes the global range so that it is.
const char* reduction = 
"// TODO: use INIT_SCALAR_VALUE,OPERATOR,FINAL_OPERATOR_ON_CHANNEL macro abstract and simplify code\n"
"// TODO: support reduce dims include batch\n"
"// TODO: support keep_dim=False\n"
"// TODO: fix channel reduce result re-pack problem\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_3_DIMS is defined twice below with the same replacement list.
// This is a benign redefinition. GLOBAL_SIZE_2_DIMS is not used by these kernels.
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// reduct_width folds the inputWidth texels of one channel block in one row
// (y = batch*inputHeight + h). The result is written at (channel_idx, y).
// Global id 0 is used only for the bounds check. When LOCAL_SIZE > 0, the local
// id in dim 0 is the lane that strides over the width.
"__kernel void reduct_width(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int inputWidth,\n"
" __private const int inputHeight,\n"
" __private const int inputChannel,\n"
" __private const int inputBatch,\n"
" __private const int inputChannelBlock,\n"
" __private const int oututWidth,\n"
" __private const int outputHeight,\n"
" __private const int outputChannel,\n"
" __private const int outputChannelBlock\n"
" ) {\n"
" const int width_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n"
" \n"
" const int batch_idx=batch_channel_idx/outputChannelBlock;\n"
" const int channel_idx=batch_channel_idx % outputChannelBlock;\n"
" const int bh=batch_idx*inputHeight+height_idx;\n"
" const int wc=channel_idx*inputWidth;\n"
" INPUT_TYPE_I4 out=(INPUT_TYPE_I4)VALUE;\n"
" \n"
"#if LOCAL_SIZE>0\n"
" const int lid=get_local_id(0);\n"
" INPUT_TYPE_I4 local sum[LOCAL_SIZE];\n"
" for(int i=lid; i<inputWidth; i+=LOCAL_SIZE){\n"
" INPUT_TYPE_I4 in=RI_DATA(input,SAMPLER,(int2)(wc+i,bh));\n"
" out=OPERATE(out,in);\n"
" }\n"
" sum[lid]=out;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=OPERATE(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out=sum[0];\n"
"#else\n"
" for(int i=0; i<inputWidth; ++i){\n"
" INPUT_TYPE_I4 in=RI_DATA(input,SAMPLER,(int2)(wc+i,bh));\n"
" out=OPERATE(out,in);\n"
" }\n"
"#endif\n"
"#ifdef GET_AVG\n"
" out=out/inputWidth;\n"
"#endif\n"
" WI_DATA(output,(int2)(channel_idx,bh),CONVERT_OUTPUT_I4(out));\n"
"}\n"
// reduct_height folds inputHeight texels of column wc = channel_idx*inputWidth + w
// within one batch. The result is written at (wc, batch_idx). When LOCAL_SIZE > 0,
// w is the work-group id in dim 0 and the local ids stride over the height.
"__kernel void reduct_height(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int inputWidth,\n"
" __private const int inputHeight,\n"
" __private const int inputChannel,\n"
" __private const int inputBatch,\n"
" __private const int inputChannelBlock,\n"
" __private const int oututWidth,\n"
" __private const int outputHeight,\n"
" __private const int outputChannel,\n"
" __private const int outputChannelBlock\n"
" ) {\n"
"#if LOCAL_SIZE>0\n"
" const int width_local_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,batch_channel_idx);\n"
" \n"
" const int width_idx=get_group_id(0);\n"
" const int batch_idx=batch_channel_idx/outputChannelBlock;\n"
" const int channel_idx=batch_channel_idx % outputChannelBlock;\n"
" \n"
" const int bh=batch_idx*inputHeight;\n"
" const int wc=channel_idx*inputWidth+width_idx;\n"
" const int lid=get_local_id(0);\n"
" INPUT_TYPE_I4 local sum[LOCAL_SIZE];\n"
" INPUT_TYPE_I4 out=(INPUT_TYPE_I4)VALUE;\n"
" for(int i=lid; i<inputHeight; i+=LOCAL_SIZE){\n"
" INPUT_TYPE_I4 in=RI_DATA(input,SAMPLER,(int2)(wc,bh+i));\n"
" out=OPERATE(out,in);\n"
" }\n"
" sum[lid]=out;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=OPERATE(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out=sum[0];\n"
"#else\n"
" const int width_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n"
" \n"
" const int batch_idx=batch_channel_idx/outputChannelBlock;\n"
" const int channel_idx=batch_channel_idx % outputChannelBlock;\n"
" \n"
" const int bh=batch_idx*inputHeight;\n"
" const int wc=channel_idx*inputWidth+width_idx;\n"
" INPUT_TYPE_I4 out=(INPUT_TYPE_I4)VALUE;\n"
" for(int i=0; i<inputHeight; ++i){\n"
" INPUT_TYPE_I4 in=RI_DATA(input,SAMPLER,(int2)(wc,bh+i));\n"
" out=OPERATE(out,in);\n"
" }\n"
"#endif\n"
" \n"
"#ifdef GET_AVG\n"
" out=out/inputHeight;\n"
"#endif\n"
" WI_DATA(output,(int2)(wc,batch_idx),CONVERT_OUTPUT_I4(out));\n"
"}\n"
// reduct_channel produces one scalar per (w, batch*inputHeight + h) position.
// It folds all 4 lanes of channel blocks 0 .. inputChannelBlock-2, then only the
// first `remain` lanes of the last block. The result goes in the .x lane and the
// other lanes are written as 0. When LOCAL_SIZE > 0, w is the work-group id in
// dim 0, and each lane collapses its 4-wide partial to .x before the __local tree.
"__kernel void reduct_channel(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int inputWidth,\n"
" __private const int inputHeight,\n"
" __private const int inputChannel,\n"
" __private const int inputBatch,\n"
" __private const int inputChannelBlock,\n"
" __private const int oututWidth,\n"
" __private const int outputHeight,\n"
" __private const int outputChannel,\n"
" __private const int outputChannelBlock\n"
" ) {\n"
"#if LOCAL_SIZE>0\n"
" const int width_local_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_idx=get_global_id(2);\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,batch_idx);\n"
" const int width_idx=get_group_id(0);\n"
" \n"
" const int bh=batch_idx*inputHeight+height_idx;\n"
" const int wc=width_idx;\n"
" int remain=inputChannel-(inputChannelBlock-1)*4;\n"
" const int lid=get_local_id(0);\n"
" INPUT_TYPE_I local sum[LOCAL_SIZE];\n"
" INPUT_TYPE_I4 out=(INPUT_TYPE_I4)VALUE;\n"
" INPUT_TYPE_I4 in;\n"
" INPUT_TYPE_I *inPtr=(INPUT_TYPE_I*)&in;\n"
" for(int i=lid; i<inputChannelBlock-1; i += LOCAL_SIZE){\n"
" in=RI_DATA(input,SAMPLER,(int2)(i*inputWidth+wc,bh));\n"
" out=OPERATE(out,in);\n"
" }\n"
" out.x=OPERATE(out.x,out.y);\n"
" out.x=OPERATE(out.x,out.z);\n"
" out.x=OPERATE(out.x,out.w);\n"
" sum[lid]=out.x;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=OPERATE(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out.x=sum[0];\n"
" in=RI_DATA(input,SAMPLER,(int2)((inputChannelBlock-1)*inputWidth+wc,bh));\n"
" for(int j=0; j<remain; ++j){\n"
" out.x=OPERATE(out.x,inPtr[j]);\n"
" }\n"
"#ifdef GET_AVG\n"
" out.x=out.x/inputChannel;\n"
"#endif\n"
" WI_DATA(output,(int2)(wc,bh),(OUTPUT_TYPE_I4)(out.x,0,0,0));\n"
" \n"
"#else\n"
" const int width_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int batch_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_idx);\n"
" \n"
" const int bh=batch_idx*inputHeight+height_idx;\n"
" const int wc=width_idx;\n"
" int remain=inputChannel-(inputChannelBlock-1)*4;\n"
" \n"
" INPUT_TYPE_I out=(INPUT_TYPE_I)VALUE;\n"
" INPUT_TYPE_I4 in;\n"
" INPUT_TYPE_I *inPtr=(INPUT_TYPE_I*)&in;\n"
" \n"
" for(int i=0; i<inputChannelBlock-1; ++i){\n"
" in=RI_DATA(input,SAMPLER,(int2)(i*inputWidth+wc,bh));\n"
" for(int j=0; j<4; ++j){\n"
" out=OPERATE(out,inPtr[j]);\n"
" }\n"
" }\n"
" in=RI_DATA(input,SAMPLER,(int2)((inputChannelBlock-1)*inputWidth+wc,bh));\n"
" for(int j=0; j<remain; ++j){\n"
" out=OPERATE(out,inPtr[j]);\n"
" }\n"
"#ifdef GET_AVG\n"
" out=out/inputChannel;\n"
"#endif\n"
" WI_DATA(output,(int2)(wc,bh),(OUTPUT_TYPE_I4)(out,0,0,0));\n"
"#endif\n"
"}\n"
// reduct_batch folds texel (wc, h) across inputBatch batches, with a row stride of
// inputHeight. The result is written at (wc, h). When LOCAL_SIZE > 0, w is the
// work-group id in dim 0. Note that batchOffset is computed but never used.
"__kernel void reduct_batch(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
" __private const int inputWidth,\n"
" __private const int inputHeight,\n"
" __private const int inputChannel,\n"
" __private const int inputBatch,\n"
" __private const int inputChannelBlock,\n"
" __private const int oututWidth,\n"
" __private const int outputHeight,\n"
" __private const int outputChannel,\n"
" __private const int outputChannelBlock\n"
" ) {\n"
"#if LOCAL_SIZE>0\n"
" const int width_local_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,channel_idx);\n"
" const int width_idx=get_group_id(0);\n"
" \n"
" const int bh=height_idx;\n"
" const int wc=channel_idx*inputWidth+width_idx;\n"
" int batchOffset=inputChannelBlock*inputHeight*inputWidth;\n"
" const int lid=get_local_id(0);\n"
" INPUT_TYPE_I4 local sum[LOCAL_SIZE];\n"
" INPUT_TYPE_I4 out=(INPUT_TYPE_I4)VALUE;\n"
" for(int i=lid; i<inputBatch; i+=LOCAL_SIZE){\n"
" INPUT_TYPE_I4 in=RI_DATA(input,SAMPLER,(int2)(wc,i*inputHeight+bh));\n"
" out=OPERATE(out,in);\n"
" }\n"
" sum[lid]=out;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=OPERATE(sum[lid],sum[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" out=sum[0];\n"
"#ifdef GET_AVG\n"
" out=out/inputBatch;\n"
"#endif\n"
" WI_DATA(output,(int2)(wc,bh),CONVERT_OUTPUT_I4(out));\n"
"#else\n"
" const int width_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" const int channel_idx=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,channel_idx);\n"
" \n"
" const int bh=height_idx;\n"
" const int wc=channel_idx*inputWidth+width_idx;\n"
" int batchOffset=inputChannelBlock*inputHeight*inputWidth;\n"
" INPUT_TYPE_I4 out=(INPUT_TYPE_I4)VALUE;\n"
" for(int i=0; i<inputBatch; ++i){\n"
" INPUT_TYPE_I4 in=RI_DATA(input,SAMPLER,(int2)(wc,i*inputHeight+bh));\n"
" out=OPERATE(out,in);\n"
" }\n"
"#ifdef GET_AVG\n"
" out=out/inputBatch;\n"
"#endif\n"
" WI_DATA(output,(int2)(wc,bh),CONVERT_OUTPUT_I4(out));\n"
"#endif\n"
"}\n"
;
}
